96 changes: 94 additions & 2 deletions contextualized/easy/ContextualizedNetworks.py
@@ -107,7 +107,7 @@ def __init__(self, **kwargs):
)

def predict_correlation(
-        self, C: np.ndarray, individual_preds: bool = True, squared: bool = True
+        self, C: np.ndarray, individual_preds: bool = False, squared: bool = True
) -> Union[np.ndarray, List[np.ndarray]]:
"""Predicts context-specific correlations between features.

@@ -182,7 +182,7 @@ def __init__(self, **kwargs):
super().__init__(ContextualizedMarkovGraph, [], [], MarkovTrainer, **kwargs)

def predict_precisions(
-        self, C: np.ndarray, individual_preds: bool = True
+        self, C: np.ndarray, individual_preds: bool = False
) -> Union[np.ndarray, List[np.ndarray]]:
"""Predicts context-specific precision matrices.
Can be converted to context-specific Markov networks by binarizing the networks and setting all non-zero entries to 1.
@@ -434,6 +434,98 @@ def predict_networks(
)
return betas

def _reconstruct_from_betas(
self, betas: np.ndarray, X_arr: np.ndarray
    ) -> np.ndarray:
        """Reconstructs features from predicted betas.

Args:
betas (np.ndarray): Coefficient matrices, shape (F, F) or (N, F, F).
X_arr (np.ndarray): Input data, shape (N, F).

Returns:
np.ndarray: Reconstructed data, shape (N, F).
"""

n_samples, n_features = X_arr.shape

B = np.array(betas, copy=True)
if B.ndim == 2:
B = np.broadcast_to(
B[None, :, :], (n_samples, n_features, n_features)
).copy()
elif B.ndim != 3:
raise ValueError(f"Expected betas 2D or 3D, got shape {B.shape}")

        # Zero the diagonal so no feature is used to reconstruct itself.
idx = np.arange(n_features)
B[:, idx, idx] = 0.0

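        # Per-sample reconstruction; assuming dag_pred_np applies X_hat[i] = X_arr[i] @ B[i].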
X_hat = dag_pred_np(X_arr, B)
return X_hat

def predict(
self,
C: np.ndarray,
X: np.ndarray,
project_to_dag: bool = True,
individual_preds: bool = False,
**kwargs,
    ) -> np.ndarray:
        """Predicts reconstructed data from context and features.

Args:
C (np.ndarray): Contextual features, shape (N, K).
X (np.ndarray): Input data, shape (N, F).
project_to_dag (bool, optional): If True, enforce DAG structure. Defaults to True.
individual_preds (bool, optional): If True, return per-bootstrap predictions. Defaults to False.
**kwargs: Additional keyword arguments.

Returns:
np.ndarray: Reconstructed predictions, shape (N, F), or (B, N, F) if individual_preds is True.
"""
X_scaled = self._maybe_scale_X(X)

betas = self.predict_networks(
C,
project_to_dag=project_to_dag,
individual_preds=individual_preds,
**kwargs,
)

        # Unify list and stacked-array outputs into a single iterable over bootstraps.
is_bootstrap_stack = isinstance(betas, np.ndarray) and betas.ndim == 4
if isinstance(betas, list) or is_bootstrap_stack:
if is_bootstrap_stack:
betas_iter = (betas[k] for k in range(betas.shape[0]))
else:
betas_iter = betas

reconstructions = [
self._reconstruct_from_betas(b, X_scaled) for b in betas_iter
]
recon_stack = np.stack(reconstructions, axis=0) # (B, N, F)

if self.normalize and self.scalers["X"] is not None:
recon_stack = np.stack(
[
self.scalers["X"].inverse_transform(recon_stack[k])
for k in range(recon_stack.shape[0])
],
axis=0,
)

if individual_preds:
return recon_stack # (B, N, F)
return self._nanrobust_mean(recon_stack, axis=0) # (N, F)

reconstructed_scaled = self._reconstruct_from_betas(betas, X_scaled)
if self.normalize and self.scalers["X"] is not None:
return self.scalers["X"].inverse_transform(reconstructed_scaled)
return reconstructed_scaled

def measure_mses(
self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
) -> Union[np.ndarray, List[np.ndarray]]:
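For context, a minimal usage sketch of the new predict method on the easy DAG wrapper. The class name ContextualizedBayesianNetworks, the constructor kwargs, and the fit arguments below are assumptions based on the rest of the library, not part of this diff:

import numpy as np
from contextualized.easy import ContextualizedBayesianNetworks

# Toy context (N x K) and feature (N x F) matrices; shapes are illustrative.
C = np.random.normal(size=(100, 4))
X = np.random.normal(size=(100, 10))

model = ContextualizedBayesianNetworks(encoder_type="mlp", num_archetypes=4)
model.fit(C, X, max_epochs=5)

# Mean reconstruction across bootstraps: shape (N, F).
X_hat = model.predict(C, X, project_to_dag=True)

# Per-bootstrap reconstructions: shape (B, N, F).
X_hat_boot = model.predict(C, X, individual_preds=True)

Aggregating across bootstraps by default matches the individual_preds=False defaults introduced elsewhere in this diff.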
6 changes: 4 additions & 2 deletions contextualized/easy/tests.py
@@ -113,11 +113,13 @@ def test_correlation(self):
encoder_type="ngam", num_archetypes=16
)
self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
-        rho = model.predict_correlation(self.C, squared=False)
+        rho = model.predict_correlation(self.C, individual_preds=True, squared=False)
assert rho.shape == (1, self.n_samples, self.x_dim, self.x_dim)
rho = model.predict_correlation(self.C, individual_preds=False, squared=False)
assert rho.shape == (self.n_samples, self.x_dim, self.x_dim), rho.shape
-        rho_squared = model.predict_correlation(self.C, squared=True)
+        rho_squared = model.predict_correlation(
+            self.C, individual_preds=True, squared=True
+        )
assert np.min(rho_squared) >= 0
assert rho_squared.shape == (1, self.n_samples, self.x_dim, self.x_dim)

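Note that individual_preds now defaults to False for predict_correlation and predict_precisions, so the leading bootstrap axis is only returned when requested explicitly. A small sketch of the expected shapes, assuming a fitted ContextualizedCorrelationNetworks model named ccn trained with a single bootstrap:

# Aggregated over bootstraps (new default): shape (n_samples, x_dim, x_dim).
rho = ccn.predict_correlation(C, squared=False)

# Per-bootstrap output (opt-in): shape (n_bootstraps, n_samples, x_dim, x_dim).
rho_boot = ccn.predict_correlation(C, individual_preds=True, squared=False)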