Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion opnmf/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.3.dev'
__version__ = '0.0.3.dev1+g88d7273.d20240120'
20 changes: 18 additions & 2 deletions opnmf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def fit_transform(self, X, init_W=None):

def transform(self, X):
    """Transform the data X according to the fitted OPNMF model.

    The OPNMF factorization is ``X ~ W @ H``. This solves for a new H
    given the basis W (``self.coef_``) learned during fitting and a
    new data matrix X, by running the solver with W held fixed.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Data matrix to be transformed. Must have the same number of
        samples (rows) as the data used to fit the model, since the
        fixed basis ``self.coef_`` has shape (n_samples, n_components).

    Returns
    -------
    H : ndarray of shape (n_components, n_features)
        Transformed data (loadings for the new X).
    mse : float
        Reconstruction error ``||X - coef_ @ H||_F``.
        NOTE(review): despite the name, this is the Frobenius norm of
        the residual, not a mean squared error; the value is kept as-is
        for backward compatibility with existing callers.
    """
    # Fail early with sklearn's standard NotFittedError if fit() was
    # never called (components_ is only set by fitting).
    check_is_fitted(self, 'components_')

    # Solve only for H: pass the already-learned basis W as fixed so
    # the solver skips the W update loop.
    _, H, _ = opnmf(X, n_components=self.n_components_,
                    W_fixed=self.coef_, max_iter=self.max_iter,
                    tol=self.tol)

    # Compute the reconstruction once and reuse it for the error.
    # (The original computed `X_reconstructed` and then discarded it,
    # re-evaluating `self.coef_ @ H` a second time inside the norm.)
    X_reconstructed = self.coef_ @ H
    mse = np.linalg.norm(X - X_reconstructed, ord='fro')

    return H, mse

def mse(self):
check_is_fitted(self)
Expand Down
81 changes: 44 additions & 37 deletions opnmf/opnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from . logging import logger, warn


def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd',
init_W=None):
def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd', init_W=None, W_fixed=None):
"""
Orthogonal projective non-negative matrix factorization.

Expand Down Expand Up @@ -40,6 +39,9 @@ def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd',
init_W: array (n_samples, n_components)
Fixed initial coefficient matrix.

W_fixed: ndarray of shape (n_samples, n_components), default=None
Fixed basis matrix. If provided, the function will only solve for H.

Returns
-------
W : ndarray of shape (n_samples, n_components)
Expand All @@ -49,42 +51,47 @@ def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd',
mse : float
Reconstruction error
"""
if init != 'custom':
if init_W is not None:
warn('Initialisation was not set to "custom" but an initial W '
'matrix was specified. This matrix will be ignored.')
logger.info(f'Initializing using {init}')
W, _ = _initialize_nmf(X, n_components, init=init)
init_W = None
if W_fixed is not None:
# Use the fixed W and skip the update loop
W = W_fixed
else:
W = init_W
delta_W = np.inf
XX = X @ X.T

with warnings.catch_warnings():
warnings.simplefilter("ignore")
for iter in range(max_iter):
old_W = W

enum = XX @ W
denom = W @ (W.T @ XX @ W)
W = W * enum / denom

W[W < 1e-16] = 1e-16
W = W / np.linalg.norm(W, ord=2)

delta_W = (np.linalg.norm(old_W - W, ord='fro') /
np.linalg.norm(old_W, ord='fro'))
if (iter % 100) == 0:
obj = np.linalg.norm(X - (W @ (W.T @ X)), ord='fro')
logger.info(f'iter={iter} diff={delta_W}, obj={obj}')
if delta_W < tol:
logger.info(f'Converged in {iter} iterations')
break

if delta_W > tol:
warn('OPNMF did not converge with '
f'tolerance = {tol} under {max_iter} iterations')
# Initialization as per the original code
if init != 'custom':
if init_W is not None:
warn('Initialisation was not set to "custom" but an initial W '
'matrix was specified. This matrix will be ignored.')
logger.info(f'Initializing using {init}')
W, _ = _initialize_nmf(X, n_components, init=init)
init_W = None
else:
W = init_W
# Main iterative loop for updating W
delta_W = np.inf
XX = X @ X.T
with warnings.catch_warnings():
warnings.simplefilter("ignore")
for iter in range(max_iter):
old_W = W

enum = XX @ W
denom = W @ (W.T @ XX @ W)
W = W * enum / denom

W[W < 1e-16] = 1e-16
W = W / np.linalg.norm(W, ord=2)

delta_W = (np.linalg.norm(old_W - W, ord='fro') /
np.linalg.norm(old_W, ord='fro'))
if (iter % 100) == 0:
obj = np.linalg.norm(X - (W @ (W.T @ X)), ord='fro')
logger.info(f'iter={iter} diff={delta_W}, obj={obj}')
if delta_W < tol:
logger.info(f'Converged in {iter} iterations')
break

if delta_W > tol:
warn('OPNMF did not converge with '
f'tolerance = {tol} under {max_iter} iterations')

H = W.T @ X

Expand Down