# Source code for impyute.util.checks

""" impyute.util.check """
from functools import wraps
import numpy as np
from impyute.util import find_null
from impyute.util import BadInputError
# pylint:disable=invalid-name

def checks(fn):
    """ Decorator that validates the data argument before calling `fn`

    The data to validate is the first positional argument of the call, or
    the ``data`` keyword argument when no positional arguments are given.

    Parameters
    ----------
    fn: callable
        Function to wrap. Its first argument is expected to be the data to
        impute.

    Returns
    -------
    callable
        Wrapper that runs the input checks and then returns
        ``fn(*args, **kwargs)``.

    Raises
    ------
    BadInputError
        If the data is not 2D, is not a numpy.ndarray, is not float data,
        or contains no NaN values (checked in that order).
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        """ Run input checks"""
        # Accept both fn(data) and fn(data=...). A missing value becomes
        # None, which has shape () and so fails the 2D check below.
        data = args[0] if args else kwargs.get("data")
        # A single 2D check: the original chain tested the same condition
        # twice, leaving the second branch unreachable.
        if not _shape_2d(data):
            raise BadInputError("No support for arrays that aren't 2D yet.")
        if not _is_ndarray(data):
            raise BadInputError("Not a np.ndarray.")
        # Must run after the ndarray check: _dtype_float reads `.dtype`.
        if not _dtype_float(data):
            raise BadInputError("Data is not float.")
        if not _nan_exists(data):
            raise BadInputError("No NaN's in given data")
        return fn(*args, **kwargs)
    return wrapper
def _shape_2d(data):
    """ True if array is 2D"""
    return len(np.shape(data)) == 2


def _shape_3d(data):
    """ True if array is 3D"""
    return len(np.shape(data)) == 3


def _is_ndarray(data):
    """ True if the array is an instance of numpy's ndarray"""
    return isinstance(data, np.ndarray)


def _dtype_float(data):
    """ True if the values in the array are floating point

    Any numpy floating dtype (float16, float32, float64, ...) is accepted.
    `data` must expose a ``dtype`` attribute (e.g. a numpy.ndarray).
    """
    # `np.float` was only an alias of Python `float`: comparing against it
    # matched float64 alone, and the alias no longer exists in NumPy 1.24+.
    return np.issubdtype(data.dtype, np.floating)


def _nan_exists(data):
    """ True if there is at least one np.nan in the array"""
    # NOTE(review): assumes find_null returns one entry per missing value,
    # so an empty result means no NaNs -- confirm against impyute.util.
    null_xy = find_null(data)
    return len(null_xy) > 0