diff --git a/fancyimpute/solver.py b/fancyimpute/solver.py index 2f966a7..54e8e66 100644 --- a/fancyimpute/solver.py +++ b/fancyimpute/solver.py @@ -182,17 +182,21 @@ def fit_transform(self, X, y=None): "Expected %s.fill() to return NumPy array but got %s" % ( self.__class__.__name__, type(X_filled))) - - X_result = self.solve(X_filled, missing_mask) - if not isinstance(X_result, np.ndarray): - raise TypeError( - "Expected %s.solve() to return NumPy array but got %s" % ( - self.__class__.__name__, - type(X_result))) - - X_result = self.project_result(X=X_result) - X_result[observed_mask] = X_original[observed_mask] - return X_result + # check if there is any missing data + if ((missing_mask == True).any()): + X_result = self.solve(X_filled, missing_mask) + + if not isinstance(X_result, np.ndarray): + raise TypeError( + "Expected %s.solve() to return NumPy array but got %s" % ( + self.__class__.__name__, + type(X_result))) + + X_result = self.project_result(X=X_result) + X_result[observed_mask] = X_original[observed_mask] + return X_result + else: + return X_filled def fit(self, X, y=None): """ diff --git a/test/test_soft_impute.py b/test/test_soft_impute.py index 9d14bbf..adb5fd8 100644 --- a/test/test_soft_impute.py +++ b/test/test_soft_impute.py @@ -1,8 +1,9 @@ from low_rank_data import XY, XY_incomplete, missing_mask from common import reconstruction_error - +import numpy as np from fancyimpute import SoftImpute + def test_soft_impute_with_low_rank_random_matrix(): solver = SoftImpute() XY_completed = solver.fit_transform(XY_incomplete) @@ -13,5 +14,14 @@ def test_soft_impute_with_low_rank_random_matrix(): name="SoftImpute") assert missing_mae < 0.1, "Error too high!" + +# test if the solver for a submodel is running for a numpy array without any missing data +def check_for_no_missing_data(): + X = np.ones((5, 5)) + Xf = SoftImpute().fit_transform(X) + assert (Xf.all() == X.all()) + + if __name__ == "__main__": test_soft_impute_with_low_rank_random_matrix() + check_for_no_missing_data()