diff --git a/netZooPy/cobra/cobra.py b/netZooPy/cobra/cobra.py index 775a9250..c5377a6b 100644 --- a/netZooPy/cobra/cobra.py +++ b/netZooPy/cobra/cobra.py @@ -1,9 +1,10 @@ import pandas as pd import numpy as np from scipy.linalg import eigh,pinv +from sklearn.linear_model import LinearRegression, Lasso -def cobra(X, expression): +def cobra(X, expression, cobra='nnls', alpha: np.float64=0.1): """ COBRA decomposes a (partial) gene co-expression matrix as a linear combination of covariate-specific components. @@ -16,6 +17,12 @@ def cobra(X, expression): design matrix of size (n, q), n = number of samples, q = number of covariates expression : np.ndarray, pd.DataFrame gene expression as a matrix of size (g, n), g = number of genes + cobra : string + regression mode + nnls: Non-negative least square (default) + nnlasso: Non-negative LASSO + deprecated: MLE estimate + alpha : Returns --------- psi : array @@ -55,26 +62,32 @@ def cobra(X, expression): Q = c_eigenvectors[:, indices_nonzero].T[::-1].T # - gtq = np.matmul(g.T, Q) d = c_eigenvalues[indices_nonzero][::-1] - # - xtx_inv = np.linalg.pinv( - np.dot(X.T, X) - ) - xtx_inv_xt = np.dot( - xtx_inv, X.T - ) + if cobra=='nnls': + model = LinearRegression(positive=True).fit(X, np.diag(d) ) + psi = np.transpose(model.coef_) + elif cobra=='nnlasso': + model == Lasso(alpha=alpha, positive=True).fit(X, np.diag(d) ) + psi = np.transpose(model.coef_) + elif cobra=='deprecated': + gtq = np.matmul(g.T, Q) + xtx_inv = np.linalg.pinv( + np.dot(X.T, X) + ) + xtx_inv_xt = np.dot( + xtx_inv, X.T + ) - # - psi = np.zeros((q, n)) + # + psi = np.zeros((q, n)) - for i in range(q): - for h in range(n): - psi[i, h] = n * np.sum([ - ( - xtx_inv_xt[i, k] * gtq[k, h] ** 2 - ) for k in range(n) - ]) + for i in range(q): + for h in range(n): + psi[i, h] = n * np.sum([ + ( + xtx_inv_xt[i, k] * gtq[k, h] ** 2 + ) for k in range(n) + ]) return psi, Q, d, g diff --git a/requirements.txt b/requirements.txt index 1f8151d5..05be41d8 100755 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ igraph joblib statsmodels click +scikit-learn diff --git a/tests/test_cobra.py b/tests/test_cobra.py index 2fe940fc..49c3d5f5 100644 --- a/tests/test_cobra.py +++ b/tests/test_cobra.py @@ -23,7 +23,7 @@ def test_cobra(): G_gt = pd.read_csv(G_gt_path, index_col=0) # Call COBRA - psi, Q, D, G = cobra.cobra(X, expression) + psi, Q, D, G = cobra.cobra(X, expression, cobra='deprecated') # Cast output to pandas psi = pd.DataFrame(psi, index=psi_gt.index, columns=psi_gt.columns)