Skip to content

Code simplifications#74

Merged
AlexanderFengler merged 11 commits intomainfrom
code-simplifications
Mar 19, 2026
Merged

Code simplifications#74
AlexanderFengler merged 11 commits intomainfrom
code-simplifications

Conversation

@AlexanderFengler
Copy link
Copy Markdown
Member

No description provided.

Copilot AI review requested due to automatic review settings March 17, 2026 01:38
@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR performs several code simplifications and adds new functionality to the LANfactory package (version bump to 0.6.0). It renames Jax MLP classes for consistency, removes unused parameters, refactors duplicated class logic, adds convenience helper functions for DataLoader creation, and introduces a new HuggingFace Hub integration module with CLI tools for model upload/download.

Changes:

  • Renames MLPJaxJaxMLP and MLPJaxFactoryJaxMLPFactory; removes generative_model_id parameter and **kwargs from TorchMLP.__init__; refactors LoadTorchMLP/LoadTorchMLPInfer into an inheritance hierarchy
  • Adds make_dataloader, make_train_valid_dataloaders, and TorchMLPFactory helper functions; consolidates network_config_cpn/opn into network_config_choice_prob aliases
  • Introduces lanfactory.hf module with upload_model/download_model functions, model card generation, and corresponding upload-hf/download-hf CLI commands

Reviewed changes

Copilot reviewed 34 out of 34 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
src/lanfactory/trainers/jax_mlp.py Renames MLPJaxJaxMLP and MLPJaxFactoryJaxMLPFactory
src/lanfactory/trainers/torch_mlp.py Removes **kwargs/generative_model_id, adds helper functions, refactors LoadTorchMLP hierarchy
src/lanfactory/trainers/init.py Updates exports with new names and helper functions
src/lanfactory/config/network_configs.py Consolidates CPN/OPN configs with aliases
src/lanfactory/hf/*.py New HuggingFace Hub integration module
src/lanfactory/cli/upload_hf.py New CLI for uploading models to HF Hub
src/lanfactory/cli/download_hf.py New CLI for downloading models from HF Hub
src/lanfactory/onnx/transform_onnx.py Removes generative_model_id=None
tests/test_jax_mlp.py Updates to use new JaxMLP/JaxMLPFactory names
tests/test_torch_mlp.py Removes generative_model_id=None from test calls
tests/test_transform_onnx.py Removes generative_model_id assertion
tests/test_mlflow_integration.py Updates to use new Jax class names
tests/test_end_to_end_*.py Updates to use new class names and removed params
tests/hf/*.py New tests for HF upload/download/model_card
tests/cli/test_hf_cli.py New CLI smoke tests
docs/**/*.ipynb Updates notebooks to use new APIs and helper functions
docs/using_huggingface.md New HF Hub documentation
pyproject.toml Adds hf optional dependency, new CLI entry points, coverage config
mkdocs.yml Adds HF docs to navigation

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

You can also share your feedback on Copilot code review. Take the survey.

Comment on lines +11 to 27
from .jax_mlp import JaxMLPFactory, JaxMLP, ModelTrainerJaxMLP

__all__ = [
# Dataset and DataLoader helpers
"DatasetTorch",
"make_dataloader",
"make_train_valid_dataloaders",
# Torch MLP
"TorchMLP",
"TorchMLPFactory",
"ModelTrainerTorchMLP",
"LoadTorchMLPInfer",
"LoadTorchMLP",
"MLPJaxFactory",
"MLPJax",
"LoadTorchMLPInfer",
# Jax MLP
"JaxMLPFactory",
"JaxMLP",
"ModelTrainerJaxMLP",
Comment on lines +30 to +32
# Backward-compatible aliases
network_config_cpn = network_config_choice_prob
network_config_opn = network_config_choice_prob
Comment on lines +53 to +54
hf = ["huggingface-hub>=0.20.0"]
all = ["mlflow>=3.6.0", "huggingface-hub>=0.20.0"]
from lanfactory.hf.upload import upload_model
from lanfactory.hf.download import download_model

# Default repository for official HSSM models
Comment on lines +98 to +102
valid_network_types = ["lan", "cpn", "opn"]
if network_type not in valid_network_types:
raise ValueError(
f"network_type must be one of {valid_network_types}, got: {network_type}"
)
Comment on lines +160 to +346
def make_dataloader(
file_ids: list[str] | list[Path],
batch_size: int,
network_type: str = "lan",
label_lower_bound: float | None = None,
shuffle: bool = True,
num_workers: int = 1,
pin_memory: bool = True,
) -> DataLoader:
"""Create a DataLoader for LAN/CPN/OPN training.

This is a convenience function that creates a DatasetTorch and wraps it
in a PyTorch DataLoader with sensible defaults.

Arguments
---------
file_ids: List of paths to training data pickle files.
batch_size: Batch size for training.
network_type: Type of network ("lan", "cpn", or "opn").
Determines the feature/label keys in the data files.
label_lower_bound: Lower bound for labels. If None and network_type
is "lan", defaults to log(1e-10).
shuffle: Whether to shuffle data (default: True).
num_workers: Number of worker processes for data loading (default: 1).
pin_memory: Whether to pin memory for faster GPU transfer (default: True).

Returns
-------
torch.utils.data.DataLoader configured for training.

Example
-------
>>> file_list = list(Path("data/lan_mlp/ddm").glob("*.pickle"))
>>> train_dl = make_dataloader(
... file_ids=file_list,
... batch_size=4096,
... network_type="lan",
... )
"""
# Set sensible defaults based on network type
if label_lower_bound is None and network_type == "lan":
label_lower_bound = np.log(1e-10)

dataset = DatasetTorch(
file_ids=file_ids,
batch_size=batch_size,
features_key=f"{network_type}_data",
label_key=f"{network_type}_labels",
label_lower_bound=label_lower_bound,
)

return DataLoader(
dataset,
shuffle=shuffle,
batch_size=None,
num_workers=num_workers,
pin_memory=pin_memory,
)


def make_train_valid_dataloaders(
file_ids: list[str] | list[Path],
batch_size: int,
network_type: str = "lan",
train_val_split: float = 0.9,
shuffle_files: bool = True,
label_lower_bound: float | None = None,
num_workers: int = 1,
pin_memory: bool = True,
) -> tuple[DataLoader, DataLoader, int]:
"""Create train and validation DataLoaders with automatic file splitting.

This is a convenience function that splits the file list into train/validation
sets and creates DataLoaders for each.

Arguments
---------
file_ids: List of paths to training data pickle files.
batch_size: Batch size for training.
network_type: Type of network ("lan", "cpn", or "opn").
train_val_split: Fraction of files to use for training (default: 0.9).
shuffle_files: Whether to shuffle files before splitting (default: True).
label_lower_bound: Lower bound for labels. If None and network_type
is "lan", defaults to log(1e-10).
num_workers: Number of worker processes for data loading (default: 1).
pin_memory: Whether to pin memory for faster GPU transfer (default: True).

Returns
-------
tuple of (train_dataloader, valid_dataloader, input_dim)

Example
-------
>>> file_list = list(Path("data/lan_mlp/ddm").glob("*.pickle"))
>>> train_dl, valid_dl, input_dim = make_train_valid_dataloaders(
... file_ids=file_list,
... batch_size=4096,
... network_type="lan",
... train_val_split=0.9,
... )
"""
import random

file_list = [str(f) for f in file_ids] # Ensure strings for consistency
if shuffle_files:
random.shuffle(file_list)

split_idx = int(len(file_list) * train_val_split)
train_files = file_list[:split_idx]
valid_files = file_list[split_idx:]

if len(train_files) == 0:
raise ValueError(
f"No training files after split. Got {len(file_list)} files "
f"with train_val_split={train_val_split}"
)
if len(valid_files) == 0:
raise ValueError(
f"No validation files after split. Got {len(file_list)} files "
f"with train_val_split={train_val_split}"
)

train_dl = make_dataloader(
file_ids=train_files,
batch_size=batch_size,
network_type=network_type,
label_lower_bound=label_lower_bound,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
)

valid_dl = make_dataloader(
file_ids=valid_files,
batch_size=batch_size,
network_type=network_type,
label_lower_bound=label_lower_bound,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
)

return train_dl, valid_dl, train_dl.dataset.input_dim


# --- Factory Functions ---


def TorchMLPFactory(
network_config: dict | str,
input_dim: int,
network_type: str | None = None,
) -> "TorchMLP":
"""Factory function to create a TorchMLP object.

This provides a consistent API with JaxMLPFactory and handles
loading network configs from pickle files.

Arguments
---------
network_config: Dictionary containing the network configuration,
or path to a pickled config file.
input_dim: Input dimension (typically from dataloader.dataset.input_dim).
network_type: Network type ("lan", "cpn", "opn"). If not provided,
will be inferred from train_output_type in network_config.

Returns
-------
TorchMLP instance ready for training.

Example
-------
>>> train_dl, valid_dl, input_dim = make_train_valid_dataloaders(...)
>>> net = TorchMLPFactory(
... network_config=network_config,
... input_dim=input_dim,
... )
"""
if isinstance(network_config, str):
with open(network_config, "rb") as f:
network_config = pickle.load(f)

return TorchMLP(
network_config=network_config,
input_shape=input_dim,
network_type=network_type,
)
Comment on lines +876 to +881
# Load network config from pickle file if string path provided
if isinstance(network_config, str):
with open(network_config, "rb") as f:
self.network_config = pickle.load(f)
elif isinstance(network_config, dict):
self.network_config = network_config
else:
raise ValueError("network config is neither a string nor a dictionary")
self.network_config = network_config
Copilot AI review requested due to automatic review settings March 17, 2026 03:18
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR combines several "code simplifications" with a new HuggingFace Hub integration module. It renames Jax MLP classes (MLPJaxJaxMLP, MLPJaxFactoryJaxMLPFactory) for consistency, removes unused parameters (generative_model_id, **kwargs, train) from TorchMLP, refactors LoadTorchMLP/LoadTorchMLPInfer into an inheritance hierarchy, adds convenience helper functions for DataLoader creation, introduces a new lanfactory.hf module for uploading/downloading models to HuggingFace Hub, and reorganizes config files with proper aliases.

Changes:

  • Renamed Jax MLP classes and removed deprecated parameters from TorchMLP; refactored LoadTorchMLPInfer to inherit from LoadTorchMLP; added TorchMLPFactory, make_dataloader, and make_train_valid_dataloaders helper functions
  • Added new lanfactory.hf module with upload/download functionality, model card generation, and CLI commands (upload-hf, download-hf)
  • Updated all tests, notebooks, and documentation to use new APIs and helper functions

Reviewed changes

Copilot reviewed 34 out of 34 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
src/lanfactory/trainers/init.py Updated exports: new names, added factory/dataloader helpers
src/lanfactory/trainers/jax_mlp.py Renamed MLPJax→JaxMLP, MLPJaxFactory→JaxMLPFactory
src/lanfactory/trainers/torch_mlp.py Removed **kwargs/generative_model_id, added helpers, refactored LoadTorchMLP hierarchy
src/lanfactory/config/network_configs.py Reorganized configs, added backward-compatible aliases
src/lanfactory/hf/init.py New HF module init
src/lanfactory/hf/upload.py New upload functionality
src/lanfactory/hf/download.py New download functionality
src/lanfactory/hf/model_card.py New model card generation
src/lanfactory/cli/upload_hf.py New upload CLI command
src/lanfactory/cli/download_hf.py New download CLI command
src/lanfactory/cli/jax_train.py Updated to use JaxMLPFactory
src/lanfactory/onnx/transform_onnx.py Removed generative_model_id
pyproject.toml Added hf extras, CLI entry points, coverage omits
tests/* Updated all tests for renamed classes and removed params
tests/hf/* New tests for HF module
tests/cli/test_hf_cli.py New CLI smoke tests
docs/* Updated tutorials and added HF documentation
mkdocs.yml Added HF docs to navigation

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

You can also share your feedback on Copilot code review. Take the survey.

Comment on lines +12 to +17
from lanfactory.hf.upload import upload_model
from lanfactory.hf.download import download_model

# Default repository for official HSSM models
DEFAULT_REPO_ID = "franklab/HSSM"

Comment on lines +53 to +54
hf = ["huggingface-hub>=0.20.0"]
all = ["mlflow>=3.6.0", "huggingface-hub>=0.20.0"]
make_dataloader,
make_train_valid_dataloaders,
)
from .jax_mlp import JaxMLPFactory, JaxMLP, ModelTrainerJaxMLP
Copilot AI review requested due to automatic review settings March 17, 2026 04:11
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR performs several code simplifications and adds new features: renaming Jax MLP classes from MLPJax/MLPJaxFactory to JaxMLP/JaxMLPFactory, removing unused parameters (generative_model_id, **kwargs) from TorchMLP, refactoring LoadTorchMLP/LoadTorchMLPInfer into an inheritance hierarchy, adding convenience helper functions for DataLoader creation, adding TorchMLPFactory, reorganizing config files, and adding a complete HuggingFace Hub integration module with upload/download CLI commands.

Changes:

  • Renames Jax MLP classes to follow JaxMLP/JaxMLPFactory naming convention (with deprecation aliases), removes unused TorchMLP parameters, and refactors LoadTorchMLPInfer to inherit from LoadTorchMLP
  • Adds make_dataloader, make_train_valid_dataloaders, and TorchMLPFactory helper functions, with corresponding tutorial updates
  • Adds lanfactory.hf module with upload_model/download_model functions, model card generation, and upload-hf/download-hf CLI commands

Reviewed changes

Copilot reviewed 35 out of 35 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
src/lanfactory/trainers/jax_mlp.py Renames MLPJaxJaxMLP, MLPJaxFactoryJaxMLPFactory
src/lanfactory/trainers/torch_mlp.py Removes **kwargs/generative_model_id, adds helpers and factory, refactors LoadTorchMLP hierarchy
src/lanfactory/trainers/init.py Updates exports, adds deprecation aliases for old names
src/lanfactory/config/network_configs.py Reorganizes configs, adds network_config_choice_prob with backward-compatible aliases
src/lanfactory/config/init.py Exports new config names
src/lanfactory/hf/init.py New HuggingFace integration module init
src/lanfactory/hf/model_card.py New model card YAML parsing and README generation
src/lanfactory/hf/upload.py New upload functionality for HuggingFace Hub
src/lanfactory/hf/download.py New download functionality from HuggingFace Hub
src/lanfactory/cli/upload_hf.py New CLI command for uploading models
src/lanfactory/cli/download_hf.py New CLI command for downloading models
src/lanfactory/cli/jax_train.py Updates MLPJaxFactoryJaxMLPFactory reference
src/lanfactory/onnx/transform_onnx.py Removes generative_model_id=None from TorchMLP call
pyproject.toml Adds hf optional dependency, CLI entry points, coverage omissions
tests/test_jax_mlp.py Updates all references to new class names
tests/test_torch_mlp.py Removes generative_model_id=None from test calls
tests/test_transform_onnx.py Removes assertion for removed parameter
tests/test_end_to_end_*.py Updates class name references and removes unused params
tests/test_mlflow_integration.py Updates MLPJaxFactoryJaxMLPFactory
tests/hf/ New test files for upload, download, and model card modules
tests/cli/test_hf_cli.py New CLI smoke tests
docs/ Updated tutorials and new HuggingFace documentation
mkdocs.yml Adds HF docs to navigation
notebooks/ Updates class name references in example notebooks

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

You can also share your feedback on Copilot code review. Take the survey.

Comment on lines +53 to +54
hf = ["huggingface-hub>=0.20.0"]
all = ["mlflow>=3.6.0", "huggingface-hub>=0.20.0"]
Copilot AI review requested due to automatic review settings March 18, 2026 20:20
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR modernizes LANfactory’s training/config/test surface to align with newer ssm-simulators, simplifies model construction/loading APIs, and adds HuggingFace Hub upload/download support (library + CLI), with docs/tests updated accordingly.

Changes:

  • Add HuggingFace Hub integration (lanfactory.hf) plus new upload-hf / download-hf CLI entry points and documentation.
  • Introduce PyTorch DataLoader convenience helpers and a TorchMLPFactory; simplify TorchMLP initialization and ONNX transform usage.
  • Rename JAX MLP identifiers to JaxMLP / JaxMLPFactory (with deprecation aliases at lanfactory.trainers), and update tests/notebooks/docs to match updated ssm-simulators config structure.

Reviewed changes

Copilot reviewed 36 out of 36 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
tests/test_transform_onnx.py Updates assertions for simplified TorchMLP init (removed generative_model_id).
tests/test_torch_mlp.py Updates TorchMLP instantiations to match new signature.
tests/test_mlflow_integration.py Updates JAX factory import/name and adapts data generation config/API to newer ssm-simulators.
tests/test_jax_mlp.py Renames tests/imports to JaxMLP / JaxMLPFactory.
tests/test_end_to_end_torch.py Updates generator config structure + data generator API; removes obsolete train arg.
tests/test_end_to_end_jax.py Updates generator config structure + JAX factory name.
tests/hf/test_upload.py Adds unit tests for HF upload functionality and defaults.
tests/hf/test_model_card.py Adds unit tests for model card YAML loading + README generation.
tests/hf/test_download.py Adds unit tests for HF download behavior, filtering, and defaults.
tests/hf/init.py Adds HF tests package marker.
tests/conftest.py Migrates dummy generator configs to ssms.config.get_default_generator_config and new nested keys.
tests/cli/test_hf_cli.py Adds CLI smoke/validation tests for upload-hf and download-hf.
src/lanfactory/trainers/torch_mlp.py Adds make_dataloader / make_train_valid_dataloaders, adds TorchMLPFactory, refactors loaders (LoadTorchMLP*).
src/lanfactory/trainers/jax_mlp.py Renames MLPJax* to JaxMLP* and updates trainer typing accordingly.
src/lanfactory/trainers/init.py Exposes new helpers/factories; adds deprecation aliases via __getattr__.
src/lanfactory/onnx/transform_onnx.py Removes deprecated generative_model_id arg from TorchMLP construction.
src/lanfactory/hf/upload.py Adds HF upload implementation including file collection and README generation.
src/lanfactory/hf/model_card.py Adds model_card YAML parsing + README frontmatter generation utilities.
src/lanfactory/hf/download.py Adds HF download implementation with include/exclude filtering.
src/lanfactory/hf/init.py Adds HF integration module exports and constants.
src/lanfactory/config/network_configs.py Reorganizes example configs; introduces choice-prob config + backward-compatible aliases.
src/lanfactory/config/init.py Exposes new choice-prob config symbols.
src/lanfactory/cli/upload_hf.py Adds Typer-based upload-hf CLI command.
src/lanfactory/cli/download_hf.py Adds Typer-based download-hf CLI command.
src/lanfactory/cli/jax_train.py Updates JAX factory call site to new name.
pyproject.toml Bumps ssm-simulators, adds Typer + HF extras, registers new CLI scripts, adjusts coverage omit list.
notebooks/test_notebooks/test_jax_network_cpn.ipynb Updates notebook usage to JaxMLPFactory.
notebooks/test_notebooks/test_jax_network.ipynb Updates notebook usage to JaxMLPFactory and traceback text.
notebooks/test_notebooks/load_jax_lan_cpn.ipynb Updates notebook usage to JaxMLPFactory.
mkdocs.yml Adds HF docs + HF API page to nav.
docs/using_huggingface.md Adds user documentation for HF upload/download + model_card.yaml.
docs/basic_tutorial/basic_tutorial_opn_torch.ipynb Switches tutorial to DataLoader helpers and TorchMLPFactory usage.
docs/basic_tutorial/basic_tutorial_lan_torch.ipynb Switches tutorial to DataLoader helpers and TorchMLPFactory usage.
docs/basic_tutorial/basic_tutorial_lan_jax.ipynb Switches tutorial to DataLoader helpers and JaxMLPFactory usage.
docs/basic_tutorial/basic_tutorial_cpn_torch.ipynb Switches tutorial to DataLoader helpers and TorchMLPFactory usage.
docs/api/hf.md Adds MkDocs API stub for lanfactory.hf.
Comments suppressed due to low confidence (1)

src/lanfactory/trainers/jax_mlp.py:43

  • JaxMLPFactory loads pickle configs via pickle.load(open(...)) without closing the file handle. Use a context manager (with open(...) as f:) to avoid leaking file descriptors and to match how other loaders in the repo handle pickles.
    if isinstance(network_config, str):
        network_config_internal = pickle.load(open(network_config, "rb"))
    elif isinstance(network_config, dict):
        network_config_internal = network_config

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

You can also share your feedback on Copilot code review. Take the survey.

Comment on lines +186 to +187
# Construct URL
url = f"https://huggingface.co/{repo_id}/tree/main/{path_in_repo}"
Comment on lines 31 to 56
@@ -38,6 +38,7 @@ dependencies = [
"frozendict>=2.4.6",
"onnx>=1.17.0",
"matplotlib>=3.10.1",
"typer>=0.9.0",
]

keywords = [
@@ -50,6 +51,8 @@ keywords = [

[project.optional-dependencies]
mlflow = ["mlflow>=3.6.0"]
hf = ["huggingface-hub>=0.20.0"]
all = ["mlflow>=3.6.0", "huggingface-hub>=0.20.0"]

Comment on lines +27 to +28
def JaxMLPFactory(network_config: dict | str = {}, train: bool = True) -> "JaxMLP":
"""Factory function to create a JaxMLP object.
Comment on lines 27 to 58
@@ -34,7 +34,7 @@ def MLPJaxFactory(network_config: dict | str = {}, train: bool = True) -> "MLPJa
Whether the model should be trained or not.
Returns
-------
MLPJax class initialized with the correct network configuration.
JaxMLP class initialized with the correct network configuration.
"""

if isinstance(network_config, str):
@@ -46,15 +46,15 @@ def MLPJaxFactory(network_config: dict | str = {}, train: bool = True) -> "MLPJa
"network_config argument is not passed as either a dictionary or a string (path to a file)!"
)

return MLPJax(
return JaxMLP(
layer_sizes=network_config_internal["layer_sizes"],
activations=network_config_internal["activations"],
train_output_type=network_config_internal["train_output_type"],
train=train,
)


class MLPJax(nn.Module):
class JaxMLP(nn.Module):
"""JaxMLP class.
batch_size=batch_size,
network_type=network_type,
label_lower_bound=label_lower_bound,
shuffle=True,
Comment on lines +840 to 845
class LoadTorchMLP:
"""General-purpose class to load TorchMLP models.

Does NOT call eval() by default - suitable for fine-tuning or further training.
For inference with eval() enabled, use LoadTorchMLPInfer instead.

Comment on lines +78 to +86
with open(yaml_path, "r") as f:
data = yaml.safe_load(f)

# Extract fields with defaults
config = ModelCardConfig(
tags=data.get("tags", ["lan", "ssm", "hssm"]),
library_name=data.get("library_name", "onnx"),
license=data.get("license", "mit"),
title=data.get("title", "LAN Model"),
@codecov
Copy link
Copy Markdown

codecov bot commented Mar 18, 2026

Codecov Report

❌ Patch coverage is 98.90110% with 2 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/lanfactory/trainers/jax_mlp.py 71.42% 1 Missing and 1 partial ⚠️
Files with missing lines Coverage Δ
src/lanfactory/config/network_configs.py 100.00% <100.00%> (ø)
src/lanfactory/hf/download.py 100.00% <100.00%> (ø)
src/lanfactory/hf/model_card.py 100.00% <100.00%> (ø)
src/lanfactory/hf/upload.py 100.00% <100.00%> (ø)
src/lanfactory/onnx/transform_onnx.py 100.00% <ø> (ø)
src/lanfactory/trainers/torch_mlp.py 94.15% <100.00%> (+2.17%) ⬆️
src/lanfactory/trainers/jax_mlp.py 93.03% <71.42%> (-0.91%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copilot AI review requested due to automatic review settings March 19, 2026 18:57
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR modernizes LANfactory’s training utilities and external integrations by adding HuggingFace Hub upload/download support, introducing Torch DataLoader helper/factory functions, and renaming the JAX MLP API to JaxMLP* while updating tests/docs accordingly.

Changes:

  • Added lanfactory.hf module plus new upload-hf / download-hf CLIs and documentation.
  • Added PyTorch helper functions (make_dataloader, make_train_valid_dataloaders) and a TorchMLPFactory, and refactored Torch model loaders.
  • Renamed JAX trainer API from MLPJax* to JaxMLP* (with deprecation aliases at lanfactory.trainers), and updated tests/tutorial notebooks for the newer ssm-simulators generator config shape.

Reviewed changes

Copilot reviewed 36 out of 36 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
tests/test_transform_onnx.py Updates ONNX transform test for removed TorchMLP kwargs.
tests/test_torch_mlp.py Removes deprecated TorchMLP args; adds tests for new dataloader/factory helpers.
tests/test_mlflow_integration.py Updates generator config schema and JAX factory rename in MLflow integration tests.
tests/test_jax_mlp.py Renames imports/usages to JaxMLPFactory/JaxMLP.
tests/test_end_to_end_torch.py Updates ssms generator usage/config schema; removes unsupported TorchMLP arg.
tests/test_end_to_end_jax.py Updates ssms generator usage/config schema; uses JaxMLPFactory.
tests/hf/test_upload.py Adds unit tests for HF upload utilities.
tests/hf/test_model_card.py Adds unit tests for model card YAML loading + README generation.
tests/hf/test_download.py Adds unit tests for HF download utilities.
tests/hf/init.py Adds test package marker for HF tests.
tests/conftest.py Switches to ssms.config.get_default_generator_config and new nested config keys.
tests/cli/test_hf_cli.py Adds subprocess-based smoke tests for new HF CLIs.
src/lanfactory/trainers/torch_mlp.py Adds dataloader helpers + Torch factory; refactors model loaders and TorchMLP signature.
src/lanfactory/trainers/jax_mlp.py Renames MLPJax* to JaxMLP* and updates trainer typing/docs accordingly.
src/lanfactory/trainers/init.py Re-exports new helpers/factories; provides deprecation aliases via __getattr__.
src/lanfactory/onnx/transform_onnx.py Removes deprecated TorchMLP kwarg when exporting to ONNX.
src/lanfactory/hf/upload.py New HF upload implementation (file collection, dry-run, repo upload).
src/lanfactory/hf/model_card.py New model card config + README generation utilities.
src/lanfactory/hf/download.py New HF download implementation (list/filter/download files).
src/lanfactory/hf/init.py New HF package exports/constants.
src/lanfactory/config/network_configs.py Reorganizes config examples; adds choice-prob configs and compatibility aliases.
src/lanfactory/config/init.py Exposes new config names via package exports.
src/lanfactory/cli/upload_hf.py New Typer CLI for uploading models to HF hub.
src/lanfactory/cli/download_hf.py New Typer CLI for downloading models from HF hub.
src/lanfactory/cli/jax_train.py Updates to use JaxMLPFactory.
pyproject.toml Updates dependencies (ssm-simulators bump, typer add, hf extras) and adds console scripts.
mkdocs.yml Adds HF guide/API docs to navigation.
docs/using_huggingface.md New user guide for HF upload/download and model_card.yaml.
docs/api/hf.md Adds API doc stub for lanfactory.hf.
notebooks/test_notebooks/test_jax_network_cpn.ipynb Updates notebook to JaxMLPFactory.
notebooks/test_notebooks/test_jax_network.ipynb Updates notebook to JaxMLPFactory.
notebooks/test_notebooks/load_jax_lan_cpn.ipynb Updates notebook to JaxMLPFactory.
docs/basic_tutorial/basic_tutorial_opn_torch.ipynb Refactors tutorial to use new dataloader helpers and Torch factory.
docs/basic_tutorial/basic_tutorial_lan_torch.ipynb Refactors tutorial to use new dataloader helpers and Torch factory.
docs/basic_tutorial/basic_tutorial_lan_jax.ipynb Refactors tutorial to use new dataloader helpers and JaxMLPFactory.
docs/basic_tutorial/basic_tutorial_cpn_torch.ipynb Refactors tutorial to use new dataloader helpers and Torch factory.
Comments suppressed due to low confidence (1)

src/lanfactory/trainers/jax_mlp.py:43

  • File loading in JaxMLPFactory uses pickle.load(open(...)) without a context manager, which can leak file handles if an exception occurs. Use with open(network_config, "rb") as f: ... for deterministic cleanup.
    if isinstance(network_config, str):
        network_config_internal = pickle.load(open(network_config, "rb"))
    elif isinstance(network_config, dict):
        network_config_internal = network_config

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +10 to +16
from lanfactory.hf.model_card import ( # noqa: E402
load_model_card_yaml,
generate_readme,
ModelCardConfig,
)
from lanfactory.hf.upload import upload_model # noqa: E402
from lanfactory.hf.download import download_model # noqa: E402
"frozendict>=2.4.6",
"onnx>=1.17.0",
"matplotlib>=3.10.1",
"typer>=0.9.0",
app = typer.Typer()


@app.command()
Comment on lines +18 to +28
app = typer.Typer()


@app.command()
def main(
network_type: str = typer.Option(
...,
"--network-type",
help="Network type: lan, cpn, or opn.",
),
model_name: str = typer.Option(
Comment on lines +12 to +13
from lanfactory.hf import DEFAULT_REPO_ID, VALID_NETWORK_TYPES

Comment on lines +11 to +12
from lanfactory.hf import DEFAULT_REPO_ID, VALID_NETWORK_TYPES

Comment on lines 27 to 30
def JaxMLPFactory(network_config: dict | str = {}, train: bool = True) -> "JaxMLP":
"""Factory function to create a JaxMLP object.
Arguments
---------
Comment on lines +30 to 33
# Backward-compatible aliases
network_config_cpn = network_config_choice_prob
network_config_opn = network_config_choice_prob

@AlexanderFengler AlexanderFengler merged commit 58cd6f9 into main Mar 19, 2026
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants