Conversation
…n of naming conventions
|
Check out this pull request on ReviewNB. See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
There was a problem hiding this comment.
Pull request overview
This PR performs several code simplifications and adds new functionality to the LANfactory package (version bump to 0.6.0). It renames Jax MLP classes for consistency, removes unused parameters, refactors duplicated class logic, adds convenience helper functions for DataLoader creation, and introduces a new HuggingFace Hub integration module with CLI tools for model upload/download.
Changes:
- Renames `MLPJax` → `JaxMLP` and `MLPJaxFactory` → `JaxMLPFactory`; removes the `generative_model_id` parameter and `**kwargs` from `TorchMLP.__init__`; refactors `LoadTorchMLP`/`LoadTorchMLPInfer` into an inheritance hierarchy
- Adds `make_dataloader`, `make_train_valid_dataloaders`, and `TorchMLPFactory` helper functions; consolidates `network_config_cpn`/`opn` into `network_config_choice_prob` aliases
- Introduces the `lanfactory.hf` module with `upload_model`/`download_model` functions, model card generation, and corresponding `upload-hf`/`download-hf` CLI commands
Reviewed changes
Copilot reviewed 34 out of 34 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| src/lanfactory/trainers/jax_mlp.py | Renames MLPJax→JaxMLP and MLPJaxFactory→JaxMLPFactory |
| src/lanfactory/trainers/torch_mlp.py | Removes **kwargs/generative_model_id, adds helper functions, refactors LoadTorchMLP hierarchy |
| src/lanfactory/trainers/__init__.py | Updates exports with new names and helper functions |
| src/lanfactory/config/network_configs.py | Consolidates CPN/OPN configs with aliases |
| src/lanfactory/hf/*.py | New HuggingFace Hub integration module |
| src/lanfactory/cli/upload_hf.py | New CLI for uploading models to HF Hub |
| src/lanfactory/cli/download_hf.py | New CLI for downloading models from HF Hub |
| src/lanfactory/onnx/transform_onnx.py | Removes generative_model_id=None |
| tests/test_jax_mlp.py | Updates to use new JaxMLP/JaxMLPFactory names |
| tests/test_torch_mlp.py | Removes generative_model_id=None from test calls |
| tests/test_transform_onnx.py | Removes generative_model_id assertion |
| tests/test_mlflow_integration.py | Updates to use new Jax class names |
| tests/test_end_to_end_*.py | Updates to use new class names and removed params |
| tests/hf/*.py | New tests for HF upload/download/model_card |
| tests/cli/test_hf_cli.py | New CLI smoke tests |
| docs/**/*.ipynb | Updates notebooks to use new APIs and helper functions |
| docs/using_huggingface.md | New HF Hub documentation |
| pyproject.toml | Adds hf optional dependency, new CLI entry points, coverage config |
| mkdocs.yml | Adds HF docs to navigation |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
| from .jax_mlp import JaxMLPFactory, JaxMLP, ModelTrainerJaxMLP | ||
|
|
||
| __all__ = [ | ||
| # Dataset and DataLoader helpers | ||
| "DatasetTorch", | ||
| "make_dataloader", | ||
| "make_train_valid_dataloaders", | ||
| # Torch MLP | ||
| "TorchMLP", | ||
| "TorchMLPFactory", | ||
| "ModelTrainerTorchMLP", | ||
| "LoadTorchMLPInfer", | ||
| "LoadTorchMLP", | ||
| "MLPJaxFactory", | ||
| "MLPJax", | ||
| "LoadTorchMLPInfer", | ||
| # Jax MLP | ||
| "JaxMLPFactory", | ||
| "JaxMLP", | ||
| "ModelTrainerJaxMLP", |
| # Backward-compatible aliases | ||
| network_config_cpn = network_config_choice_prob | ||
| network_config_opn = network_config_choice_prob |
| hf = ["huggingface-hub>=0.20.0"] | ||
| all = ["mlflow>=3.6.0", "huggingface-hub>=0.20.0"] |
src/lanfactory/hf/__init__.py
Outdated
| from lanfactory.hf.upload import upload_model | ||
| from lanfactory.hf.download import download_model | ||
|
|
||
| # Default repository for official HSSM models |
src/lanfactory/hf/upload.py
Outdated
| valid_network_types = ["lan", "cpn", "opn"] | ||
| if network_type not in valid_network_types: | ||
| raise ValueError( | ||
| f"network_type must be one of {valid_network_types}, got: {network_type}" | ||
| ) |
| def make_dataloader( | ||
| file_ids: list[str] | list[Path], | ||
| batch_size: int, | ||
| network_type: str = "lan", | ||
| label_lower_bound: float | None = None, | ||
| shuffle: bool = True, | ||
| num_workers: int = 1, | ||
| pin_memory: bool = True, | ||
| ) -> DataLoader: | ||
| """Create a DataLoader for LAN/CPN/OPN training. | ||
|
|
||
| This is a convenience function that creates a DatasetTorch and wraps it | ||
| in a PyTorch DataLoader with sensible defaults. | ||
|
|
||
| Arguments | ||
| --------- | ||
| file_ids: List of paths to training data pickle files. | ||
| batch_size: Batch size for training. | ||
| network_type: Type of network ("lan", "cpn", or "opn"). | ||
| Determines the feature/label keys in the data files. | ||
| label_lower_bound: Lower bound for labels. If None and network_type | ||
| is "lan", defaults to log(1e-10). | ||
| shuffle: Whether to shuffle data (default: True). | ||
| num_workers: Number of worker processes for data loading (default: 1). | ||
| pin_memory: Whether to pin memory for faster GPU transfer (default: True). | ||
|
|
||
| Returns | ||
| ------- | ||
| torch.utils.data.DataLoader configured for training. | ||
|
|
||
| Example | ||
| ------- | ||
| >>> file_list = list(Path("data/lan_mlp/ddm").glob("*.pickle")) | ||
| >>> train_dl = make_dataloader( | ||
| ... file_ids=file_list, | ||
| ... batch_size=4096, | ||
| ... network_type="lan", | ||
| ... ) | ||
| """ | ||
| # Set sensible defaults based on network type | ||
| if label_lower_bound is None and network_type == "lan": | ||
| label_lower_bound = np.log(1e-10) | ||
|
|
||
| dataset = DatasetTorch( | ||
| file_ids=file_ids, | ||
| batch_size=batch_size, | ||
| features_key=f"{network_type}_data", | ||
| label_key=f"{network_type}_labels", | ||
| label_lower_bound=label_lower_bound, | ||
| ) | ||
|
|
||
| return DataLoader( | ||
| dataset, | ||
| shuffle=shuffle, | ||
| batch_size=None, | ||
| num_workers=num_workers, | ||
| pin_memory=pin_memory, | ||
| ) | ||
|
|
||
|
|
||
| def make_train_valid_dataloaders( | ||
| file_ids: list[str] | list[Path], | ||
| batch_size: int, | ||
| network_type: str = "lan", | ||
| train_val_split: float = 0.9, | ||
| shuffle_files: bool = True, | ||
| label_lower_bound: float | None = None, | ||
| num_workers: int = 1, | ||
| pin_memory: bool = True, | ||
| ) -> tuple[DataLoader, DataLoader, int]: | ||
| """Create train and validation DataLoaders with automatic file splitting. | ||
|
|
||
| This is a convenience function that splits the file list into train/validation | ||
| sets and creates DataLoaders for each. | ||
|
|
||
| Arguments | ||
| --------- | ||
| file_ids: List of paths to training data pickle files. | ||
| batch_size: Batch size for training. | ||
| network_type: Type of network ("lan", "cpn", or "opn"). | ||
| train_val_split: Fraction of files to use for training (default: 0.9). | ||
| shuffle_files: Whether to shuffle files before splitting (default: True). | ||
| label_lower_bound: Lower bound for labels. If None and network_type | ||
| is "lan", defaults to log(1e-10). | ||
| num_workers: Number of worker processes for data loading (default: 1). | ||
| pin_memory: Whether to pin memory for faster GPU transfer (default: True). | ||
|
|
||
| Returns | ||
| ------- | ||
| tuple of (train_dataloader, valid_dataloader, input_dim) | ||
|
|
||
| Example | ||
| ------- | ||
| >>> file_list = list(Path("data/lan_mlp/ddm").glob("*.pickle")) | ||
| >>> train_dl, valid_dl, input_dim = make_train_valid_dataloaders( | ||
| ... file_ids=file_list, | ||
| ... batch_size=4096, | ||
| ... network_type="lan", | ||
| ... train_val_split=0.9, | ||
| ... ) | ||
| """ | ||
| import random | ||
|
|
||
| file_list = [str(f) for f in file_ids] # Ensure strings for consistency | ||
| if shuffle_files: | ||
| random.shuffle(file_list) | ||
|
|
||
| split_idx = int(len(file_list) * train_val_split) | ||
| train_files = file_list[:split_idx] | ||
| valid_files = file_list[split_idx:] | ||
|
|
||
| if len(train_files) == 0: | ||
| raise ValueError( | ||
| f"No training files after split. Got {len(file_list)} files " | ||
| f"with train_val_split={train_val_split}" | ||
| ) | ||
| if len(valid_files) == 0: | ||
| raise ValueError( | ||
| f"No validation files after split. Got {len(file_list)} files " | ||
| f"with train_val_split={train_val_split}" | ||
| ) | ||
|
|
||
| train_dl = make_dataloader( | ||
| file_ids=train_files, | ||
| batch_size=batch_size, | ||
| network_type=network_type, | ||
| label_lower_bound=label_lower_bound, | ||
| shuffle=True, | ||
| num_workers=num_workers, | ||
| pin_memory=pin_memory, | ||
| ) | ||
|
|
||
| valid_dl = make_dataloader( | ||
| file_ids=valid_files, | ||
| batch_size=batch_size, | ||
| network_type=network_type, | ||
| label_lower_bound=label_lower_bound, | ||
| shuffle=True, | ||
| num_workers=num_workers, | ||
| pin_memory=pin_memory, | ||
| ) | ||
|
|
||
| return train_dl, valid_dl, train_dl.dataset.input_dim | ||
|
|
||
|
|
||
| # --- Factory Functions --- | ||
|
|
||
|
|
||
| def TorchMLPFactory( | ||
| network_config: dict | str, | ||
| input_dim: int, | ||
| network_type: str | None = None, | ||
| ) -> "TorchMLP": | ||
| """Factory function to create a TorchMLP object. | ||
|
|
||
| This provides a consistent API with JaxMLPFactory and handles | ||
| loading network configs from pickle files. | ||
|
|
||
| Arguments | ||
| --------- | ||
| network_config: Dictionary containing the network configuration, | ||
| or path to a pickled config file. | ||
| input_dim: Input dimension (typically from dataloader.dataset.input_dim). | ||
| network_type: Network type ("lan", "cpn", "opn"). If not provided, | ||
| will be inferred from train_output_type in network_config. | ||
|
|
||
| Returns | ||
| ------- | ||
| TorchMLP instance ready for training. | ||
|
|
||
| Example | ||
| ------- | ||
| >>> train_dl, valid_dl, input_dim = make_train_valid_dataloaders(...) | ||
| >>> net = TorchMLPFactory( | ||
| ... network_config=network_config, | ||
| ... input_dim=input_dim, | ||
| ... ) | ||
| """ | ||
| if isinstance(network_config, str): | ||
| with open(network_config, "rb") as f: | ||
| network_config = pickle.load(f) | ||
|
|
||
| return TorchMLP( | ||
| network_config=network_config, | ||
| input_shape=input_dim, | ||
| network_type=network_type, | ||
| ) |
| # Load network config from pickle file if string path provided | ||
| if isinstance(network_config, str): | ||
| with open(network_config, "rb") as f: | ||
| self.network_config = pickle.load(f) | ||
| elif isinstance(network_config, dict): | ||
| self.network_config = network_config | ||
| else: | ||
| raise ValueError("network config is neither a string nor a dictionary") | ||
| self.network_config = network_config |
There was a problem hiding this comment.
Pull request overview
This PR combines several "code simplifications" with a new HuggingFace Hub integration module. It renames Jax MLP classes (MLPJax → JaxMLP, MLPJaxFactory → JaxMLPFactory) for consistency, removes unused parameters (generative_model_id, **kwargs, train) from TorchMLP, refactors LoadTorchMLP/LoadTorchMLPInfer into an inheritance hierarchy, adds convenience helper functions for DataLoader creation, introduces a new lanfactory.hf module for uploading/downloading models to HuggingFace Hub, and reorganizes config files with proper aliases.
Changes:
- Renamed Jax MLP classes and removed deprecated parameters from `TorchMLP`; refactored `LoadTorchMLPInfer` to inherit from `LoadTorchMLP`; added `TorchMLPFactory`, `make_dataloader`, and `make_train_valid_dataloaders` helper functions
- Added new `lanfactory.hf` module with upload/download functionality, model card generation, and CLI commands (`upload-hf`, `download-hf`)
- Updated all tests, notebooks, and documentation to use new APIs and helper functions
Reviewed changes
Copilot reviewed 34 out of 34 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| src/lanfactory/trainers/__init__.py | Updated exports: new names, added factory/dataloader helpers |
| src/lanfactory/trainers/jax_mlp.py | Renamed MLPJax→JaxMLP, MLPJaxFactory→JaxMLPFactory |
| src/lanfactory/trainers/torch_mlp.py | Removed **kwargs/generative_model_id, added helpers, refactored LoadTorchMLP hierarchy |
| src/lanfactory/config/network_configs.py | Reorganized configs, added backward-compatible aliases |
| src/lanfactory/hf/__init__.py | New HF module init |
| src/lanfactory/hf/upload.py | New upload functionality |
| src/lanfactory/hf/download.py | New download functionality |
| src/lanfactory/hf/model_card.py | New model card generation |
| src/lanfactory/cli/upload_hf.py | New upload CLI command |
| src/lanfactory/cli/download_hf.py | New download CLI command |
| src/lanfactory/cli/jax_train.py | Updated to use JaxMLPFactory |
| src/lanfactory/onnx/transform_onnx.py | Removed generative_model_id |
| pyproject.toml | Added hf extras, CLI entry points, coverage omits |
| tests/* | Updated all tests for renamed classes and removed params |
| tests/hf/* | New tests for HF module |
| tests/cli/test_hf_cli.py | New CLI smoke tests |
| docs/* | Updated tutorials and added HF documentation |
| mkdocs.yml | Added HF docs to navigation |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
src/lanfactory/hf/__init__.py
Outdated
| from lanfactory.hf.upload import upload_model | ||
| from lanfactory.hf.download import download_model | ||
|
|
||
| # Default repository for official HSSM models | ||
| DEFAULT_REPO_ID = "franklab/HSSM" | ||
|
|
| hf = ["huggingface-hub>=0.20.0"] | ||
| all = ["mlflow>=3.6.0", "huggingface-hub>=0.20.0"] |
| make_dataloader, | ||
| make_train_valid_dataloaders, | ||
| ) | ||
| from .jax_mlp import JaxMLPFactory, JaxMLP, ModelTrainerJaxMLP |
There was a problem hiding this comment.
Pull request overview
This PR performs several code simplifications and adds new features: renaming Jax MLP classes from MLPJax/MLPJaxFactory to JaxMLP/JaxMLPFactory, removing unused parameters (generative_model_id, **kwargs) from TorchMLP, refactoring LoadTorchMLP/LoadTorchMLPInfer into an inheritance hierarchy, adding convenience helper functions for DataLoader creation, adding TorchMLPFactory, reorganizing config files, and adding a complete HuggingFace Hub integration module with upload/download CLI commands.
Changes:
- Renames Jax MLP classes to follow the `JaxMLP`/`JaxMLPFactory` naming convention (with deprecation aliases), removes unused `TorchMLP` parameters, and refactors `LoadTorchMLPInfer` to inherit from `LoadTorchMLP`
- Adds `make_dataloader`, `make_train_valid_dataloaders`, and `TorchMLPFactory` helper functions, with corresponding tutorial updates
- Adds the `lanfactory.hf` module with `upload_model`/`download_model` functions, model card generation, and `upload-hf`/`download-hf` CLI commands
Reviewed changes
Copilot reviewed 35 out of 35 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| src/lanfactory/trainers/jax_mlp.py | Renames MLPJax → JaxMLP, MLPJaxFactory → JaxMLPFactory |
| src/lanfactory/trainers/torch_mlp.py | Removes **kwargs/generative_model_id, adds helpers and factory, refactors LoadTorchMLP hierarchy |
| src/lanfactory/trainers/__init__.py | Updates exports, adds deprecation aliases for old names |
| src/lanfactory/config/network_configs.py | Reorganizes configs, adds network_config_choice_prob with backward-compatible aliases |
| src/lanfactory/config/__init__.py | Exports new config names |
| src/lanfactory/hf/__init__.py | New HuggingFace integration module init |
| src/lanfactory/hf/model_card.py | New model card YAML parsing and README generation |
| src/lanfactory/hf/upload.py | New upload functionality for HuggingFace Hub |
| src/lanfactory/hf/download.py | New download functionality from HuggingFace Hub |
| src/lanfactory/cli/upload_hf.py | New CLI command for uploading models |
| src/lanfactory/cli/download_hf.py | New CLI command for downloading models |
| src/lanfactory/cli/jax_train.py | Updates MLPJaxFactory → JaxMLPFactory reference |
| src/lanfactory/onnx/transform_onnx.py | Removes generative_model_id=None from TorchMLP call |
| pyproject.toml | Adds hf optional dependency, CLI entry points, coverage omissions |
| tests/test_jax_mlp.py | Updates all references to new class names |
| tests/test_torch_mlp.py | Removes generative_model_id=None from test calls |
| tests/test_transform_onnx.py | Removes assertion for removed parameter |
| tests/test_end_to_end_*.py | Updates class name references and removes unused params |
| tests/test_mlflow_integration.py | Updates MLPJaxFactory → JaxMLPFactory |
| tests/hf/ | New test files for upload, download, and model card modules |
| tests/cli/test_hf_cli.py | New CLI smoke tests |
| docs/ | Updated tutorials and new HuggingFace documentation |
| mkdocs.yml | Adds HF docs to navigation |
| notebooks/ | Updates class name references in example notebooks |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
| hf = ["huggingface-hub>=0.20.0"] | ||
| all = ["mlflow>=3.6.0", "huggingface-hub>=0.20.0"] |
There was a problem hiding this comment.
Pull request overview
This PR modernizes LANfactory’s training/config/test surface to align with newer ssm-simulators, simplifies model construction/loading APIs, and adds HuggingFace Hub upload/download support (library + CLI), with docs/tests updated accordingly.
Changes:
- Add HuggingFace Hub integration (`lanfactory.hf`) plus new `upload-hf`/`download-hf` CLI entry points and documentation.
- Introduce PyTorch DataLoader convenience helpers and a `TorchMLPFactory`; simplify `TorchMLP` initialization and ONNX transform usage.
- Rename JAX MLP identifiers to `JaxMLP`/`JaxMLPFactory` (with deprecation aliases at `lanfactory.trainers`), and update tests/notebooks/docs to match the updated `ssm-simulators` config structure.
Reviewed changes
Copilot reviewed 36 out of 36 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_transform_onnx.py | Updates assertions for simplified TorchMLP init (removed generative_model_id). |
| tests/test_torch_mlp.py | Updates TorchMLP instantiations to match new signature. |
| tests/test_mlflow_integration.py | Updates JAX factory import/name and adapts data generation config/API to newer ssm-simulators. |
| tests/test_jax_mlp.py | Renames tests/imports to JaxMLP / JaxMLPFactory. |
| tests/test_end_to_end_torch.py | Updates generator config structure + data generator API; removes obsolete train arg. |
| tests/test_end_to_end_jax.py | Updates generator config structure + JAX factory name. |
| tests/hf/test_upload.py | Adds unit tests for HF upload functionality and defaults. |
| tests/hf/test_model_card.py | Adds unit tests for model card YAML loading + README generation. |
| tests/hf/test_download.py | Adds unit tests for HF download behavior, filtering, and defaults. |
| tests/hf/__init__.py | Adds HF tests package marker. |
| tests/conftest.py | Migrates dummy generator configs to ssms.config.get_default_generator_config and new nested keys. |
| tests/cli/test_hf_cli.py | Adds CLI smoke/validation tests for upload-hf and download-hf. |
| src/lanfactory/trainers/torch_mlp.py | Adds make_dataloader / make_train_valid_dataloaders, adds TorchMLPFactory, refactors loaders (LoadTorchMLP*). |
| src/lanfactory/trainers/jax_mlp.py | Renames MLPJax* to JaxMLP* and updates trainer typing accordingly. |
| src/lanfactory/trainers/__init__.py | Exposes new helpers/factories; adds deprecation aliases via __getattr__. |
| src/lanfactory/onnx/transform_onnx.py | Removes deprecated generative_model_id arg from TorchMLP construction. |
| src/lanfactory/hf/upload.py | Adds HF upload implementation including file collection and README generation. |
| src/lanfactory/hf/model_card.py | Adds model_card YAML parsing + README frontmatter generation utilities. |
| src/lanfactory/hf/download.py | Adds HF download implementation with include/exclude filtering. |
| src/lanfactory/hf/__init__.py | Adds HF integration module exports and constants. |
| src/lanfactory/config/network_configs.py | Reorganizes example configs; introduces choice-prob config + backward-compatible aliases. |
| src/lanfactory/config/__init__.py | Exposes new choice-prob config symbols. |
| src/lanfactory/cli/upload_hf.py | Adds Typer-based upload-hf CLI command. |
| src/lanfactory/cli/download_hf.py | Adds Typer-based download-hf CLI command. |
| src/lanfactory/cli/jax_train.py | Updates JAX factory call site to new name. |
| pyproject.toml | Bumps ssm-simulators, adds Typer + HF extras, registers new CLI scripts, adjusts coverage omit list. |
| notebooks/test_notebooks/test_jax_network_cpn.ipynb | Updates notebook usage to JaxMLPFactory. |
| notebooks/test_notebooks/test_jax_network.ipynb | Updates notebook usage to JaxMLPFactory and traceback text. |
| notebooks/test_notebooks/load_jax_lan_cpn.ipynb | Updates notebook usage to JaxMLPFactory. |
| mkdocs.yml | Adds HF docs + HF API page to nav. |
| docs/using_huggingface.md | Adds user documentation for HF upload/download + model_card.yaml. |
| docs/basic_tutorial/basic_tutorial_opn_torch.ipynb | Switches tutorial to DataLoader helpers and TorchMLPFactory usage. |
| docs/basic_tutorial/basic_tutorial_lan_torch.ipynb | Switches tutorial to DataLoader helpers and TorchMLPFactory usage. |
| docs/basic_tutorial/basic_tutorial_lan_jax.ipynb | Switches tutorial to DataLoader helpers and JaxMLPFactory usage. |
| docs/basic_tutorial/basic_tutorial_cpn_torch.ipynb | Switches tutorial to DataLoader helpers and TorchMLPFactory usage. |
| docs/api/hf.md | Adds MkDocs API stub for lanfactory.hf. |
Comments suppressed due to low confidence (1)
src/lanfactory/trainers/jax_mlp.py:43
`JaxMLPFactory` loads pickle configs via `pickle.load(open(...))` without closing the file handle. Use a context manager (`with open(...) as f:`) to avoid leaking file descriptors and to match how other loaders in the repo handle pickles.
if isinstance(network_config, str):
network_config_internal = pickle.load(open(network_config, "rb"))
elif isinstance(network_config, dict):
network_config_internal = network_config
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
src/lanfactory/hf/upload.py
Outdated
| # Construct URL | ||
| url = f"https://huggingface.co/{repo_id}/tree/main/{path_in_repo}" |
| @@ -38,6 +38,7 @@ dependencies = [ | |||
| "frozendict>=2.4.6", | |||
| "onnx>=1.17.0", | |||
| "matplotlib>=3.10.1", | |||
| "typer>=0.9.0", | |||
| ] | |||
|
|
|||
| keywords = [ | |||
| @@ -50,6 +51,8 @@ keywords = [ | |||
|
|
|||
| [project.optional-dependencies] | |||
| mlflow = ["mlflow>=3.6.0"] | |||
| hf = ["huggingface-hub>=0.20.0"] | |||
| all = ["mlflow>=3.6.0", "huggingface-hub>=0.20.0"] | |||
|
|
|||
src/lanfactory/trainers/jax_mlp.py
Outdated
| def JaxMLPFactory(network_config: dict | str = {}, train: bool = True) -> "JaxMLP": | ||
| """Factory function to create a JaxMLP object. |
src/lanfactory/trainers/jax_mlp.py
Outdated
| @@ -34,7 +34,7 @@ def MLPJaxFactory(network_config: dict | str = {}, train: bool = True) -> "MLPJa | |||
| Whether the model should be trained or not. | |||
| Returns | |||
| ------- | |||
| MLPJax class initialized with the correct network configuration. | |||
| JaxMLP class initialized with the correct network configuration. | |||
| """ | |||
|
|
|||
| if isinstance(network_config, str): | |||
| @@ -46,15 +46,15 @@ def MLPJaxFactory(network_config: dict | str = {}, train: bool = True) -> "MLPJa | |||
| "network_config argument is not passed as either a dictionary or a string (path to a file)!" | |||
| ) | |||
|
|
|||
| return MLPJax( | |||
| return JaxMLP( | |||
| layer_sizes=network_config_internal["layer_sizes"], | |||
| activations=network_config_internal["activations"], | |||
| train_output_type=network_config_internal["train_output_type"], | |||
| train=train, | |||
| ) | |||
|
|
|||
|
|
|||
| class MLPJax(nn.Module): | |||
| class JaxMLP(nn.Module): | |||
| """JaxMLP class. | |||
src/lanfactory/trainers/torch_mlp.py
Outdated
| batch_size=batch_size, | ||
| network_type=network_type, | ||
| label_lower_bound=label_lower_bound, | ||
| shuffle=True, |
| class LoadTorchMLP: | ||
| """General-purpose class to load TorchMLP models. | ||
|
|
||
| Does NOT call eval() by default - suitable for fine-tuning or further training. | ||
| For inference with eval() enabled, use LoadTorchMLPInfer instead. | ||
|
|
| with open(yaml_path, "r") as f: | ||
| data = yaml.safe_load(f) | ||
|
|
||
| # Extract fields with defaults | ||
| config = ModelCardConfig( | ||
| tags=data.get("tags", ["lan", "ssm", "hssm"]), | ||
| library_name=data.get("library_name", "onnx"), | ||
| license=data.get("license", "mit"), | ||
| title=data.get("title", "LAN Model"), |
Codecov Report❌ Patch coverage is
🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Pull request overview
This PR modernizes LANfactory’s training utilities and external integrations by adding HuggingFace Hub upload/download support, introducing Torch DataLoader helper/factory functions, and renaming the JAX MLP API to JaxMLP* while updating tests/docs accordingly.
Changes:
- Added the `lanfactory.hf` module plus new `upload-hf`/`download-hf` CLIs and documentation.
- Added PyTorch helper functions (`make_dataloader`, `make_train_valid_dataloaders`) and a `TorchMLPFactory`, and refactored Torch model loaders.
- Renamed the JAX trainer API from `MLPJax*` to `JaxMLP*` (with deprecation aliases at `lanfactory.trainers`), and updated tests/tutorial notebooks for the newer `ssm-simulators` generator config shape.
Reviewed changes
Copilot reviewed 36 out of 36 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_transform_onnx.py | Updates ONNX transform test for removed TorchMLP kwargs. |
| tests/test_torch_mlp.py | Removes deprecated TorchMLP args; adds tests for new dataloader/factory helpers. |
| tests/test_mlflow_integration.py | Updates generator config schema and JAX factory rename in MLflow integration tests. |
| tests/test_jax_mlp.py | Renames imports/usages to JaxMLPFactory/JaxMLP. |
| tests/test_end_to_end_torch.py | Updates ssms generator usage/config schema; removes unsupported TorchMLP arg. |
| tests/test_end_to_end_jax.py | Updates ssms generator usage/config schema; uses JaxMLPFactory. |
| tests/hf/test_upload.py | Adds unit tests for HF upload utilities. |
| tests/hf/test_model_card.py | Adds unit tests for model card YAML loading + README generation. |
| tests/hf/test_download.py | Adds unit tests for HF download utilities. |
| tests/hf/__init__.py | Adds test package marker for HF tests. |
| tests/conftest.py | Switches to ssms.config.get_default_generator_config and new nested config keys. |
| tests/cli/test_hf_cli.py | Adds subprocess-based smoke tests for new HF CLIs. |
| src/lanfactory/trainers/torch_mlp.py | Adds dataloader helpers + Torch factory; refactors model loaders and TorchMLP signature. |
| src/lanfactory/trainers/jax_mlp.py | Renames MLPJax* to JaxMLP* and updates trainer typing/docs accordingly. |
| src/lanfactory/trainers/__init__.py | Re-exports new helpers/factories; provides deprecation aliases via __getattr__. |
| src/lanfactory/onnx/transform_onnx.py | Removes deprecated TorchMLP kwarg when exporting to ONNX. |
| src/lanfactory/hf/upload.py | New HF upload implementation (file collection, dry-run, repo upload). |
| src/lanfactory/hf/model_card.py | New model card config + README generation utilities. |
| src/lanfactory/hf/download.py | New HF download implementation (list/filter/download files). |
| src/lanfactory/hf/__init__.py | New HF package exports/constants. |
| src/lanfactory/config/network_configs.py | Reorganizes config examples; adds choice-prob configs and compatibility aliases. |
| src/lanfactory/config/__init__.py | Exposes new config names via package exports. |
| src/lanfactory/cli/upload_hf.py | New Typer CLI for uploading models to HF hub. |
| src/lanfactory/cli/download_hf.py | New Typer CLI for downloading models from HF hub. |
| src/lanfactory/cli/jax_train.py | Updates to use JaxMLPFactory. |
| pyproject.toml | Updates dependencies (ssm-simulators bump, typer add, hf extras) and adds console scripts. |
| mkdocs.yml | Adds HF guide/API docs to navigation. |
| docs/using_huggingface.md | New user guide for HF upload/download and model_card.yaml. |
| docs/api/hf.md | Adds API doc stub for lanfactory.hf. |
| notebooks/test_notebooks/test_jax_network_cpn.ipynb | Updates notebook to JaxMLPFactory. |
| notebooks/test_notebooks/test_jax_network.ipynb | Updates notebook to JaxMLPFactory. |
| notebooks/test_notebooks/load_jax_lan_cpn.ipynb | Updates notebook to JaxMLPFactory. |
| docs/basic_tutorial/basic_tutorial_opn_torch.ipynb | Refactors tutorial to use new dataloader helpers and Torch factory. |
| docs/basic_tutorial/basic_tutorial_lan_torch.ipynb | Refactors tutorial to use new dataloader helpers and Torch factory. |
| docs/basic_tutorial/basic_tutorial_lan_jax.ipynb | Refactors tutorial to use new dataloader helpers and JaxMLPFactory. |
| docs/basic_tutorial/basic_tutorial_cpn_torch.ipynb | Refactors tutorial to use new dataloader helpers and Torch factory. |
Comments suppressed due to low confidence (1)
src/lanfactory/trainers/jax_mlp.py:43
- File loading in `JaxMLPFactory` uses `pickle.load(open(...))` without a context manager, which can leak file handles if an exception occurs. Use `with open(network_config, "rb") as f: ...` for deterministic cleanup.
if isinstance(network_config, str):
network_config_internal = pickle.load(open(network_config, "rb"))
elif isinstance(network_config, dict):
network_config_internal = network_config
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| from lanfactory.hf.model_card import ( # noqa: E402 | ||
| load_model_card_yaml, | ||
| generate_readme, | ||
| ModelCardConfig, | ||
| ) | ||
| from lanfactory.hf.upload import upload_model # noqa: E402 | ||
| from lanfactory.hf.download import download_model # noqa: E402 |
| "frozendict>=2.4.6", | ||
| "onnx>=1.17.0", | ||
| "matplotlib>=3.10.1", | ||
| "typer>=0.9.0", |
| app = typer.Typer() | ||
|
|
||
|
|
||
| @app.command() |
| app = typer.Typer() | ||
|
|
||
|
|
||
| @app.command() | ||
| def main( | ||
| network_type: str = typer.Option( | ||
| ..., | ||
| "--network-type", | ||
| help="Network type: lan, cpn, or opn.", | ||
| ), | ||
| model_name: str = typer.Option( |
| from lanfactory.hf import DEFAULT_REPO_ID, VALID_NETWORK_TYPES | ||
|
|
| from lanfactory.hf import DEFAULT_REPO_ID, VALID_NETWORK_TYPES | ||
|
|
src/lanfactory/trainers/jax_mlp.py
Outdated
| def JaxMLPFactory(network_config: dict | str = {}, train: bool = True) -> "JaxMLP": | ||
| """Factory function to create a JaxMLP object. | ||
| Arguments | ||
| --------- |
| # Backward-compatible aliases | ||
| network_config_cpn = network_config_choice_prob | ||
| network_config_opn = network_config_choice_prob | ||
|
|
No description provided.