diff --git a/opnmf/_version.py b/opnmf/_version.py index 36a1ce9..09c9d8d 100644 --- a/opnmf/_version.py +++ b/opnmf/_version.py @@ -1 +1 @@ -__version__ = '0.0.3.dev' +__version__ = '0.0.3.dev1+g88d7273.d20240120' diff --git a/opnmf/model.py b/opnmf/model.py index f57c6bd..8acf7f6 100644 --- a/opnmf/model.py +++ b/opnmf/model.py @@ -123,6 +123,7 @@ def fit_transform(self, X, init_W=None): def transform(self, X): """Transform the data X according to the fitted OPNMF model. + Added functionality by Alfie Wearn 2024-01-20 Parameters ---------- @@ -131,10 +132,25 @@ def transform(self, X): Returns ------- - W : ndarray of shape (n_samples, n_components) + H : ndarray of shape (n_components, n_features) Transformed data. + As the OPNMF is: X~W*H, this calculates a new H given pre-calculated W and a new X. """ - raise NotImplementedError("Don't know how to do this!") + # Ensure the model is fitted + check_is_fitted(self, 'components_') + + # Apply transformation to new subjects + # Use the fixed W (self.coef_) learned during training to transform the new data + _, H, _ = opnmf(X, n_components=self.n_components_, W_fixed=self.coef_, + max_iter=self.max_iter, tol=self.tol) + + # Calculate the reconstruction + X_reconstructed = self.coef_ @ H + + # Calculate MSE + mse = np.linalg.norm(X - (self.coef_ @ H), ord='fro') + + return H, mse def mse(self): check_is_fitted(self) diff --git a/opnmf/opnmf.py b/opnmf/opnmf.py index 8743952..4d766e1 100644 --- a/opnmf/opnmf.py +++ b/opnmf/opnmf.py @@ -6,8 +6,7 @@ from . logging import logger, warn -def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd', - init_W=None): +def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd', init_W=None, W_fixed=None): """ Orthogonal projective non-negative matrix factorization. @@ -40,6 +39,9 @@ def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd', init_W: array (n_samples, n_components) Fixed initial coefficient matrix. + W_fixed: ndarray of shape (n_samples, n_components), default=None + Fixed basis matrix. If provided, the function will only solve for H. + Returns ------- W : ndarray of shape (n_samples, n_components) @@ -49,42 +51,47 @@ def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd', mse : float Reconstruction error """ - if init != 'custom': - if init_W is not None: - warn('Initialisation was not set to "custom" but an initial W ' - 'matrix was specified. This matrix will be ignored.') - logger.info(f'Initializing using {init}') - W, _ = _initialize_nmf(X, n_components, init=init) - init_W = None + if W_fixed is not None: + # Use the fixed W and skip the update loop + W = W_fixed else: - W = init_W - delta_W = np.inf - XX = X @ X.T - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - for iter in range(max_iter): - old_W = W - - enum = XX @ W - denom = W @ (W.T @ XX @ W) - W = W * enum / denom - - W[W < 1e-16] = 1e-16 - W = W / np.linalg.norm(W, ord=2) - - delta_W = (np.linalg.norm(old_W - W, ord='fro') / - np.linalg.norm(old_W, ord='fro')) - if (iter % 100) == 0: - obj = np.linalg.norm(X - (W @ (W.T @ X)), ord='fro') - logger.info(f'iter={iter} diff={delta_W}, obj={obj}') - if delta_W < tol: - logger.info(f'Converged in {iter} iterations') - break - - if delta_W > tol: - warn('OPNMF did not converge with ' - f'tolerance = {tol} under {max_iter} iterations') + # Initialization as per the original code + if init != 'custom': + if init_W is not None: + warn('Initialisation was not set to "custom" but an initial W ' + 'matrix was specified. This matrix will be ignored.') + logger.info(f'Initializing using {init}') + W, _ = _initialize_nmf(X, n_components, init=init) + init_W = None + else: + W = init_W + # Main iterative loop for updating W + delta_W = np.inf + XX = X @ X.T + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for iter in range(max_iter): + old_W = W + + enum = XX @ W + denom = W @ (W.T @ XX @ W) + W = W * enum / denom + + W[W < 1e-16] = 1e-16 + W = W / np.linalg.norm(W, ord=2) + + delta_W = (np.linalg.norm(old_W - W, ord='fro') / + np.linalg.norm(old_W, ord='fro')) + if (iter % 100) == 0: + obj = np.linalg.norm(X - (W @ (W.T @ X)), ord='fro') + logger.info(f'iter={iter} diff={delta_W}, obj={obj}') + if delta_W < tol: + logger.info(f'Converged in {iter} iterations') + break + + if delta_W > tol: + warn('OPNMF did not converge with ' + f'tolerance = {tol} under {max_iter} iterations') H = W.T @ X