Factory function. Takes the config['model'] dict and returns a BilinearImageClassifier.
decomposition.py
| Symbol | Description |
|---|---|
| `build_bilinear_tensor(model)` | Folds the bilinear block and classifier head into one tensor of shape `[d_output, d_hidden, d_hidden]`. Single-layer models only. |
| `symmetrize_bilinear_tensor(T)` | Returns `(T + T.mT) / 2`, with the same shape as `T`. |
| `decompose_bilinear_model(model)` | Main entry point. Runs the full pipeline and returns a `DecompositionArtifacts` dataclass. |
| `project_eigenvectors_to_input(eigvecs, embed_w)` | Maps eigenvectors from hidden space back to input pixel space. |
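A NumPy sketch of the same pipeline on a toy single-layer model, assuming the bilinear block computes `(W e) * (V e)` on an embedding `e = E x` and the head applies `U`. The weight names, toy sizes, and the einsum-based projection are hypothetical stand-ins, not the repo's code. Under that assumption each logit is a quadratic form, `logit_c = eᵀ Q_c e` with `Q_c[a, b] = Σ_j U[c, j] W[j, a] V[j, b]`, and symmetrizing `Q_c` leaves that form unchanged, which is why the eigendecomposition can work on the symmetric part.

```python
import numpy as np

# Hypothetical toy sizes (the MNIST model uses d_input=784, d_hidden=256, d_output=10).
rng = np.random.default_rng(0)
d_input, d_hidden, d_bi, d_output = 12, 8, 6, 3
E = rng.normal(size=(d_hidden, d_input))  # embedding weight, nn.Linear layout [out, in]
W = rng.normal(size=(d_bi, d_hidden))     # left bilinear projection (assumed)
V = rng.normal(size=(d_bi, d_hidden))     # right bilinear projection (assumed)
U = rng.normal(size=(d_output, d_bi))     # classifier head (assumed)


def forward(x):
    """Toy forward pass: logits = U @ ((W e) * (V e)) with e = E x."""
    e = E @ x
    return U @ ((W @ e) * (V @ e))


# Fold head and bilinear block: Q[c, a, b] = sum_j U[c, j] * W[j, a] * V[j, b].
Q = np.einsum("cj,ja,jb->cab", U, W, V)       # [d_output, d_hidden, d_hidden]
Q_sym = (Q + Q.transpose(0, 2, 1)) / 2        # symmetric part; quadratic form unchanged
eigvals, eigvecs = np.linalg.eigh(Q_sym)      # ascending eigenvalues; eigenvectors are columns
# Project each hidden-space eigenvector v to input space as E^T v, stored one per row.
eigvecs_input = np.einsum("chk,hi->cki", eigvecs, E)  # [d_output, d_hidden, d_input]

x = rng.normal(size=d_input)
e = E @ x
logits_from_tensor = np.einsum("a,cab,b->c", e, Q_sym, e)
# Each logit is sum_k lambda_k * (u_k . x)^2, with u_k an input-space eigenvector.
logits_from_eigs = np.einsum("ck,ck->c", eigvals, (eigvecs_input @ x) ** 2)
```

The last line shows why the eigenvalues are interpretable: each logit is a weighted sum of squared projections of the input onto the input-space eigenvectors, with positive eigenvalues pushing the logit up and negative ones pulling it down.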
DecompositionArtifacts fields (all Tensors):
| Field | Shape (MNIST example) | Description |
|---|---|---|
| `bilinear_tensor` | `[10, 256, 256]` | Raw `Q_c` matrices before symmetrization. |
| `symmetrized_tensor` | `[10, 256, 256]` | `(Q + Qᵀ) / 2`: symmetric matrices for eigendecomposition. |
| `eigenvalues` | `[10, 256]` | Eigenvalues per class in ascending order, so the last entry is the largest. |
| `eigenvectors_hidden` | `[10, 256, 256]` | Eigenvectors in hidden space. Column `[:, :, k]` pairs with `eigenvalues[:, k]`. |
| `eigenvectors_input` | `[10, 256, 784]` | Eigenvectors projected to input pixel space, one `[784]` vector per class and eigenvector. Reshape `[784]` → `[28, 28]` to visualize. |
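Pulling the strongest components for one class follows from these ordering conventions. The NumPy sketch below uses random stand-in arrays with the MNIST shapes; `top_components` is a hypothetical helper, and the `[class, eigenvector, pixel]` indexing of `eigenvectors_input` is an assumption inferred from its shape.

```python
import numpy as np

# Random stand-ins with the MNIST shapes from the table; real values would come
# from the DecompositionArtifacts fields (e.g. tensor.cpu().numpy()).
rng = np.random.default_rng(0)
n_classes, d_hidden, d_input = 10, 256, 784
eigenvalues = np.sort(rng.normal(size=(n_classes, d_hidden)), axis=-1)  # ascending
eigenvectors_hidden = rng.normal(size=(n_classes, d_hidden, d_hidden))
eigenvectors_input = rng.normal(size=(n_classes, d_hidden, d_input))


def top_components(c, k=3):
    """Return the k largest eigenvalues of class c with their eigenvectors.

    Hidden-space eigenvectors are columns ([:, :, k]); the input-space array is
    assumed to be indexed [class, eigenvector, pixel].
    """
    idx = np.arange(d_hidden)[::-1][:k]                     # ascending order: read from the end
    vals = eigenvalues[c, idx]                               # [k], largest first
    hidden = eigenvectors_hidden[c][:, idx].T                # [k, d_hidden], one eigenvector per row
    images = eigenvectors_input[c, idx].reshape(k, 28, 28)   # [k, 28, 28] for plotting
    return vals, hidden, images
```

For example, `vals, hidden, images = top_components(4, k=5)` returns the five most positive components for class 4; reading from the start of the arrays instead gives the most negative ones.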
data.py
| Symbol | Description |
|---|---|
| `build_image_dataloaders(dataset_cfg, train_cfg)` | Returns a `DatasetBundle(train_loader, test_loader, input_shape, num_classes)`. Currently supports `mnist` and `fashion_mnist`. |
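A `DatasetBundle` can be consumed like a named tuple. The sketch below mirrors it as a `typing.NamedTuple`; the class body, the placeholder loaders, the `flat_input_dim` helper, and the `(1, 28, 28)` MNIST input shape are assumptions, not the repo's definition. Flattening `input_shape` gives the `784` input width that appears in the `eigenvectors_input` shape.

```python
import math
from typing import Any, NamedTuple, Tuple


class DatasetBundle(NamedTuple):
    """Hypothetical mirror of the bundle returned by build_image_dataloaders."""
    train_loader: Any
    test_loader: Any
    input_shape: Tuple[int, ...]
    num_classes: int


def flat_input_dim(bundle: DatasetBundle) -> int:
    """Number of input features after flattening, e.g. 784 for 28x28 images."""
    return math.prod(bundle.input_shape)


# Placeholder loaders; the real ones would be torch DataLoaders.
bundle = DatasetBundle(train_loader=[], test_loader=[], input_shape=(1, 28, 28), num_classes=10)
train_loader, test_loader, input_shape, num_classes = bundle  # tuple-style unpacking
```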
train.py
| Symbol | Description |
|---|---|
| `train_image_experiment(config)` | Runs the full training loop. Returns a dict of paths with keys `run_dir`, `metrics_path`, `best_checkpoint_path`, and `latest_checkpoint_path`. |
| `evaluate(model, loader, criterion, device)` | Returns `(avg_loss, accuracy)`. Already decorated with `@torch.no_grad()`, so callers do not need their own `no_grad` context. |
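The arithmetic behind `(avg_loss, accuracy)` can be sketched in NumPy, assuming cross-entropy loss averaged per sample (not per batch) over the whole loader. `evaluate_batches` is a hypothetical stand-in that takes precomputed `(logits, labels)` batches instead of running a model; the repo's `evaluate` may weight batches differently.

```python
import numpy as np


def evaluate_batches(batches):
    """Return (avg_loss, accuracy) over an iterable of (logits, labels) batches.

    logits: float array [batch, num_classes]; labels: int array [batch].
    Cross-entropy is summed over samples and divided by the total sample
    count, so a short final batch is not over-weighted.
    """
    total_loss, correct, seen = 0.0, 0, 0
    for logits, labels in batches:
        shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable log-softmax
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
        total_loss += float(-log_probs[np.arange(len(labels)), labels].sum())
        correct += int((logits.argmax(axis=1) == labels).sum())
        seen += len(labels)
    return total_loss / seen, correct / seen
```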
utils.py
| Symbol | Description |
|---|---|
| `load_checkpoint(path, map_location='cpu')` | Loads a `.pt` checkpoint. Returns a dict with keys `model_state_dict`, `config`, `epoch`, and `metrics`. |
| `resolve_device(name)` | Resolves `'auto'` to `cuda` / `mps` / `cpu`. |
| `set_seed(seed)` | Seeds Python, NumPy, and PyTorch for reproducibility. |
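The `'auto'` resolution can be sketched as a pure function. This assumes the preference order CUDA, then MPS, then CPU, and that names other than `'auto'` pass through unchanged; `resolve_device_name` and its availability flags are hypothetical. In PyTorch the flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
def resolve_device_name(name: str, cuda_available: bool, mps_available: bool) -> str:
    """Map 'auto' to the first available of cuda, mps, cpu (assumed order).

    Availability is passed in so the logic can be exercised without a GPU.
    """
    if name != "auto":
        return name  # explicit choices such as "cpu" or "cuda:1" pass through
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```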
Checkpoint format
```python
{
    'model_state_dict': OrderedDict,  # pass to model.load_state_dict()
    'config': dict,                   # the full YAML config used for training
    'epoch': int,                     # epoch number when saved
    'metrics': dict,                  # {'epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc', 'lr'}
}
```
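Before calling `model.load_state_dict(ckpt['model_state_dict'])`, it can help to confirm that a loaded checkpoint has the documented layout. A minimal sketch with a hypothetical `summarize_checkpoint` helper that operates on the dict returned by `load_checkpoint`; the fake checkpoint is illustrative only.

```python
REQUIRED_KEYS = {"model_state_dict", "config", "epoch", "metrics"}
METRIC_KEYS = {"epoch", "train_loss", "train_acc", "val_loss", "val_acc", "lr"}


def summarize_checkpoint(ckpt: dict) -> str:
    """Validate the documented checkpoint layout and return a one-line summary."""
    missing = REQUIRED_KEYS - set(ckpt)
    if missing:
        raise KeyError(f"checkpoint is missing keys: {sorted(missing)}")
    metrics = ckpt["metrics"]
    missing_metrics = METRIC_KEYS - set(metrics)
    if missing_metrics:
        raise KeyError(f"metrics dict is missing keys: {sorted(missing_metrics)}")
    return (
        f"epoch {ckpt['epoch']}: "
        f"val_loss={metrics['val_loss']:.4f}, val_acc={metrics['val_acc']:.4f}"
    )


# Illustrative fake checkpoint; a real one would come from load_checkpoint(path).
fake = {
    "model_state_dict": {},
    "config": {"model": {}},
    "epoch": 3,
    "metrics": {"epoch": 3, "train_loss": 0.1, "train_acc": 0.97,
                "val_loss": 0.0625, "val_acc": 0.98, "lr": 1e-3},
}
print(summarize_checkpoint(fake))  # epoch 3: val_loss=0.0625, val_acc=0.9800
```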