From 1ddacc11f261e96e15e2fd699affb6f128f86307 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 16 Jan 2026 18:18:24 +0800 Subject: [PATCH 01/11] some code simplifications around data loaders, and some homogenization of naming conventions --- .../basic_tutorial_cpn_torch.ipynb | 75 ++-- .../basic_tutorial_lan_jax.ipynb | 68 +--- .../basic_tutorial_lan_torch.ipynb | 81 ++-- .../basic_tutorial_opn_torch.ipynb | 81 ++-- .../test_notebooks/load_jax_lan_cpn.ipynb | 4 +- .../test_notebooks/test_jax_network.ipynb | 6 +- .../test_notebooks/test_jax_network_cpn.ipynb | 4 +- pyproject.toml | 2 + src/lanfactory/cli/jax_train.py | 2 +- src/lanfactory/config/network_configs.py | 54 +-- src/lanfactory/onnx/transform_onnx.py | 1 - src/lanfactory/trainers/__init__.py | 26 +- src/lanfactory/trainers/jax_mlp.py | 16 +- src/lanfactory/trainers/torch_mlp.py | 366 +++++++++++++----- tests/test_end_to_end_jax.py | 4 +- tests/test_end_to_end_torch.py | 1 - tests/test_jax_mlp.py | 86 ++-- tests/test_mlflow_integration.py | 4 +- tests/test_torch_mlp.py | 24 +- tests/test_transform_onnx.py | 1 - 20 files changed, 507 insertions(+), 399 deletions(-) diff --git a/docs/basic_tutorial/basic_tutorial_cpn_torch.ipynb b/docs/basic_tutorial/basic_tutorial_cpn_torch.ipynb index 3c26a83..19ae454 100755 --- a/docs/basic_tutorial/basic_tutorial_cpn_torch.ipynb +++ b/docs/basic_tutorial/basic_tutorial_cpn_torch.ipynb @@ -166,60 +166,40 @@ "source": [ "#### Prepare for Training\n", "\n", - "Next we set up dataloaders for training with pytorch. The `LANfactory` uses custom dataloaders, taking into account particularities of the expected training data.\n", - "Specifically, we expect to receive a bunch of training data files (the present example generates only one), where each file hosts a large number of training examples. \n", - "So we want to define a dataloader which spits out batches from data with a specific training data file, and keeps checking when to load in a new file. \n", - "The way this is implemented here, is via the `DatasetTorch` class in `lanfactory.trainers`, which inherits from `torch.utils.data.Dataset` and prespecifies a `batch_size`. Finally this is supplied to a [`DataLoader`](https://pytorch.org/docs/stable/data.html), for which we keep the `batch_size` argument at 0.\n", + "Next we set up dataloaders for training with pytorch. The `LANfactory` provides convenient helper functions for this.\n", "\n", - "The `DatasetTorch` class is then called as an iterator via the DataLoader and takes care of batching as well as file loading internally. \n", + "The `make_train_valid_dataloaders` function handles:\n", + "- Splitting your data files into training and validation sets\n", + "- Creating the appropriate `DatasetTorch` objects\n", + "- Wrapping them in PyTorch `DataLoader` objects with sensible defaults\n", "\n", - "You may choose your own way of defining the `DataLoader` classes, downstream you are simply expected to supply one." + "Under the hood, this uses the `DatasetTorch` class which handles batching and file loading internally." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# MAKE DATALOADERS\n", "TRAINING_TYPE = \"cpn\"\n", "\n", - "# List of datafiles (here only one)\n", + "# List of datafiles\n", "folder_ = Path(\"torch_nb_data/\") / \"cpn_mlp\" / \"training_data\"\n", - "file_list_ = list(folder_.glob(\"*\"))\n", - "\n", - "# Training dataset\n", - "torch_training_dataset = lanfactory.trainers.DatasetTorch(\n", - " file_ids=file_list_,\n", - " batch_size=BATCH_SIZE,\n", - " features_key=f\"{TRAINING_TYPE}_data\",\n", - " label_key=f\"{TRAINING_TYPE}_labels\",\n", - ")\n", - "\n", - "torch_training_dataloader = torch.utils.data.DataLoader(\n", - " torch_training_dataset,\n", - " shuffle=True,\n", - " batch_size=None,\n", - " num_workers=1,\n", - " pin_memory=True,\n", - ")\n", + "file_list_ = list(folder_.glob(\"*.pickle\"))\n", "\n", - "# Validation dataset\n", - "torch_validation_dataset = lanfactory.trainers.DatasetTorch(\n", + "# Create train and validation dataloaders with a single function call\n", + "torch_training_dataloader, torch_validation_dataloader, input_dim = lanfactory.trainers.make_train_valid_dataloaders(\n", " file_ids=file_list_,\n", " batch_size=BATCH_SIZE,\n", - " features_key=f\"{TRAINING_TYPE}_data\",\n", - " label_key=f\"{TRAINING_TYPE}_labels\",\n", + " network_type=\"cpn\",\n", + " train_val_split=0.9,\n", ")\n", "\n", - "torch_validation_dataloader = torch.utils.data.DataLoader(\n", - " torch_validation_dataset,\n", - " shuffle=True,\n", - " batch_size=None,\n", - " num_workers=1,\n", - " pin_memory=True,\n", - ")" + "print(f\"Training batches: {len(torch_training_dataloader)}\")\n", + "print(f\"Validation batches: {len(torch_validation_dataloader)}\")\n", + "print(f\"Input dimension: {input_dim}\")" ] }, { @@ -275,7 +255,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -302,20 +282,26 @@ ], "source": [ "# LOAD NETWORK\n", - "net = lanfactory.trainers.TorchMLP(\n", + "# Option 1: Using the factory function (recommended)\n", + "net = lanfactory.trainers.TorchMLPFactory(\n", " network_config=deepcopy(network_config),\n", - " input_shape=torch_training_dataset.input_dim,\n", - " save_folder=Path(\"torch_nb_data/torch_models\") / TRAINING_TYPE / MODEL,\n", - " generative_model_id=MODEL,\n", + " input_dim=input_dim,\n", + " network_type=\"cpn\",\n", ")\n", "\n", + "# Option 2: Direct instantiation (also works)\n", + "# net = lanfactory.trainers.TorchMLP(\n", + "# network_config=deepcopy(network_config),\n", + "# input_shape=input_dim,\n", + "# network_type=\"cpn\",\n", + "# )\n", + "\n", "# SAVE CONFIGS\n", "lanfactory.utils.save_configs(\n", " model_id=MODEL + \"_torch_\",\n", " save_folder=Path(\"torch_nb_data/torch_models\") / TRAINING_TYPE / MODEL,\n", " network_config=network_config,\n", " train_config=train_config,\n", - " allow_abs_path_folder_generation=True,\n", ")" ] }, @@ -465,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -482,10 +468,11 @@ "model_path = Path(\"torch_nb_data/torch_models\") / TRAINING_TYPE / MODEL\n", "network_file_path = next(model_path.glob(\"*state_dict*\"))\n", "\n", + "# LoadTorchMLPInfer automatically calls eval() for inference mode\n", "network = lanfactory.trainers.LoadTorchMLPInfer(\n", " model_file_path=network_file_path,\n", " network_config=network_config,\n", - " input_dim=torch_training_dataset.input_dim,\n", + " input_dim=input_dim,\n", " network_type=TRAINING_TYPE,\n", ")" ] diff --git a/docs/basic_tutorial/basic_tutorial_lan_jax.ipynb b/docs/basic_tutorial/basic_tutorial_lan_jax.ipynb index bdb000d..f977def 100644 --- a/docs/basic_tutorial/basic_tutorial_lan_jax.ipynb +++ b/docs/basic_tutorial/basic_tutorial_lan_jax.ipynb @@ -192,66 +192,38 @@ "source": [ "#### Prepare for Training\n", "\n", - "Next we set up dataloaders for training with pytorch. The `LANfactory` uses custom dataloaders, taking into account particularities of the expected training data.\n", - "Specifically, we expect to receive a bunch of training data files (the present example generates only one), where each file hosts a large number of training examples. \n", - "So we want to define a dataloader which spits out batches from data with a specific training data file, and keeps checking when to load in a new file. \n", - "The way this is implemented here, is via the `DatasetTorch` class in `lanfactory.trainers`, which inherits from `torch.utils.data.Dataset` and prespecifies a `batch_size`. Finally this is supplied to a [`DataLoader`](https://pytorch.org/docs/stable/data.html), for which we keep the `batch_size` argument at 0.\n", + "Next we set up dataloaders for training. The `LANfactory` provides convenient helper functions for this.\n", "\n", - "The `DatasetTorch` class is then called as an iterator via the DataLoader and takes care of batching as well as file loading internally. \n", + "The `make_train_valid_dataloaders` function handles:\n", + "- Splitting your data files into training and validation sets\n", + "- Creating the appropriate `DatasetTorch` objects\n", + "- Wrapping them in PyTorch `DataLoader` objects with sensible defaults\n", "\n", - "You may choose your own way of defining the `DataLoader` classes, downstream you are simply expected to supply one." + "The data is returned as numpy arrays, which JAX handles seamlessly." ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# folder_ = OUT_FOLDER\n", - "file_list_ = [os.path.join(OUT_FOLDER, file_) for file_ in os.listdir(OUT_FOLDER)]\n", + "# MAKE DATALOADERS\n", "\n", + "# List of datafiles\n", + "file_list_ = list(OUT_FOLDER.glob(\"*.pickle\"))\n", "\n", - "# INDEPENDENT TESTS OF DATALOADERS\n", - "# Training dataset\n", - "jax_training_dataset = lanfactory.trainers.DatasetTorch(\n", + "# Create train and validation dataloaders with a single function call\n", + "jax_training_dataloader, jax_validation_dataloader, input_dim = lanfactory.trainers.make_train_valid_dataloaders(\n", " file_ids=file_list_,\n", - " batch_size=(\n", - " train_config[DEVICE + \"_batch_size\"]\n", - " if torch.cuda.is_available()\n", - " else train_config[DEVICE + \"_batch_size\"]\n", - " ),\n", - " label_lower_bound=np.log(1e-10),\n", - " features_key=\"lan_data\",\n", - " label_key=\"lan_labels\",\n", - " out_framework=\"jax\",\n", + " batch_size=train_config[\"cpu_batch_size\"],\n", + " network_type=\"lan\",\n", + " train_val_split=0.9,\n", ")\n", "\n", - "jax_training_dataloader = torch.utils.data.DataLoader(\n", - " jax_training_dataset, shuffle=True, batch_size=None, num_workers=1, pin_memory=True\n", - ")\n", - "\n", - "# Validation dataset\n", - "jax_validation_dataset = lanfactory.trainers.DatasetTorch(\n", - " file_ids=file_list_,\n", - " batch_size=(\n", - " train_config[DEVICE + \"_batch_size\"]\n", - " if torch.cuda.is_available()\n", - " else train_config[DEVICE + \"_batch_size\"]\n", - " ),\n", - " label_lower_bound=np.log(1e-10),\n", - " features_key=\"lan_data\",\n", - " label_key=\"lan_labels\",\n", - " out_framework=\"jax\",\n", - ")\n", - "\n", - "jax_validation_dataloader = torch.utils.data.DataLoader(\n", - " jax_validation_dataset,\n", - " shuffle=True,\n", - " batch_size=None,\n", - " num_workers=1,\n", - " pin_memory=True,\n", - ")" + "print(f\"Training batches: {len(jax_training_dataloader)}\")\n", + "print(f\"Validation batches: {len(jax_validation_dataloader)}\")\n", + "print(f\"Input dimension: {input_dim}\")" ] }, { @@ -269,7 +241,7 @@ "source": [ "# LOAD NETWORK\n", "# Test properties of network\n", - "jax_net = lanfactory.trainers.MLPJaxFactory(network_config=network_config, train=True)\n", + "jax_net = lanfactory.trainers.JaxMLPFactory(network_config=network_config, train=True)\n", "\n", "# Save model config\n", "# model_folder = os.path.join(\"data\", \"jax_models\", MODEL)\n", @@ -386,7 +358,7 @@ "# Loaded Net\n", "# Test passing network config as path and as object\n", "\n", - "jax_infer = lanfactory.trainers.MLPJaxFactory(\n", + "jax_infer = lanfactory.trainers.JaxMLPFactory(\n", "\t network_config=network_config,\n", " train=False,\n", " )" diff --git a/docs/basic_tutorial/basic_tutorial_lan_torch.ipynb b/docs/basic_tutorial/basic_tutorial_lan_torch.ipynb index d01cecd..aad7f44 100755 --- a/docs/basic_tutorial/basic_tutorial_lan_torch.ipynb +++ b/docs/basic_tutorial/basic_tutorial_lan_torch.ipynb @@ -211,64 +211,39 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Next we set up dataloaders for training with pytorch. The `LANfactory` uses custom dataloaders, taking into account particularities of the expected training data.\n", - "Specifically, we expect to receive a bunch of training data files (the present example generates only one), where each file hosts a large number of training examples. \n", - "So we want to define a dataloader which spits out batches from data with a specific training data file, and keeps checking when to load in a new file. \n", - "The way this is implemented here, is via the `DatasetTorch` class in `lanfactory.trainers`, which inherits from `torch.utils.data.Dataset` and prespecifies a `batch_size`. Finally this is supplied to a [`DataLoader`](https://pytorch.org/docs/stable/data.html), for which we keep the `batch_size` argument at 0.\n", + "Next we set up dataloaders for training with pytorch. The `LANfactory` provides convenient helper functions for this.\n", "\n", - "The `DatasetTorch` class is then called as an iterator via the DataLoader and takes care of batching as well as file loading internally. \n", + "The `make_train_valid_dataloaders` function handles:\n", + "- Splitting your data files into training and validation sets\n", + "- Creating the appropriate `DatasetTorch` objects\n", + "- Wrapping them in PyTorch `DataLoader` objects with sensible defaults\n", "\n", - "You may choose your own way of defining the `DataLoader` classes, downstream you are simply expected to supply one." + "Under the hood, this uses the `DatasetTorch` class which handles batching and file loading internally." ] }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# \n", - "TRAINING_TYPE = 'lan' # 'lan', 'cpn', 'opn'\n", - "\n", "# MAKE DATALOADERS\n", "\n", - "# List of datafiles (here only one)\n", - "folder_ = Path(\"torch_nb_data/lan_mlp\") / MODEL # + \"/training_data_0_nbins_0_n_1000/\"\n", - "file_list_ = [str(p) for p in folder_.iterdir()]\n", + "# List of datafiles\n", + "folder_ = Path(\"torch_nb_data/lan_mlp\") / MODEL\n", + "file_list_ = list(folder_.glob(\"*.pickle\"))\n", "\n", - "# Training dataset\n", - "torch_training_dataset = lanfactory.trainers.DatasetTorch(\n", + "# Create train and validation dataloaders with a single function call\n", + "torch_training_dataloader, torch_validation_dataloader, input_dim = lanfactory.trainers.make_train_valid_dataloaders(\n", " file_ids=file_list_,\n", " batch_size=train_config[\"cpu_batch_size\"],\n", - " features_key=f\"{TRAINING_TYPE}_data\",\n", - " label_key=f\"{TRAINING_TYPE}_labels\",\n", - " label_lower_bound=np.log(1e-10)\n", - ")\n", - "\n", - "torch_training_dataloader = torch.utils.data.DataLoader(\n", - " torch_training_dataset,\n", - " shuffle=True,\n", - " batch_size=None,\n", - " num_workers=1,\n", - " pin_memory=True,\n", - ")\n", - "\n", - "# Validation dataset\n", - "torch_validation_dataset = lanfactory.trainers.DatasetTorch(\n", - " file_ids=file_list_,\n", - " batch_size=train_config[\"cpu_batch_size\"],\n", - " features_key=f\"{TRAINING_TYPE}_data\",\n", - " label_key=f\"{TRAINING_TYPE}_labels\",\n", - " label_lower_bound=np.log(1e-10)\n", + " network_type=\"lan\",\n", + " train_val_split=0.9,\n", ")\n", "\n", - "torch_validation_dataloader = torch.utils.data.DataLoader(\n", - " torch_validation_dataset,\n", - " shuffle=True,\n", - " batch_size=None,\n", - " num_workers=1,\n", - " pin_memory=True,\n", - ")" + "print(f\"Training batches: {len(torch_training_dataloader)}\")\n", + "print(f\"Validation batches: {len(torch_validation_dataloader)}\")\n", + "print(f\"Input dimension: {input_dim}\")" ] }, { @@ -280,7 +255,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -304,21 +279,26 @@ ], "source": [ "# LOAD NETWORK\n", - "net = lanfactory.trainers.TorchMLP(\n", + "# Option 1: Using the factory function (recommended)\n", + "net = lanfactory.trainers.TorchMLPFactory(\n", " network_config=deepcopy(network_config),\n", + " input_dim=input_dim,\n", " network_type=\"lan\",\n", - " input_shape=torch_training_dataset.input_dim,\n", - " save_folder=str(Path(\"/torch_nb_data/torch_models/\")),\n", - " generative_model_id=MODEL,\n", ")\n", "\n", + "# Option 2: Direct instantiation (also works)\n", + "# net = lanfactory.trainers.TorchMLP(\n", + "# network_config=deepcopy(network_config),\n", + "# input_shape=input_dim,\n", + "# network_type=\"lan\",\n", + "# )\n", + "\n", "# SAVE CONFIGS\n", "lanfactory.utils.save_configs(\n", " model_id=MODEL + \"_torch_\",\n", " save_folder=str(Path(\"torch_nb_data/torch_models/lan/\" + MODEL + \"/\")),\n", " network_config=network_config,\n", " train_config=train_config,\n", - " allow_abs_path_folder_generation=True,\n", ")" ] }, @@ -514,7 +494,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -534,10 +514,11 @@ "model_path = Path(\"torch_nb_data/torch_models\") / (MODEL + \"_lan\")\n", "network_file_path = next(model_path.glob(\"*state_dict*\"))\n", "\n", + "# LoadTorchMLPInfer automatically calls eval() for inference mode\n", "network = lanfactory.trainers.LoadTorchMLPInfer(\n", " model_file_path=network_file_path,\n", " network_config=network_config,\n", - " input_dim=torch_training_dataset.input_dim,\n", + " input_dim=input_dim,\n", ")" ] }, diff --git a/docs/basic_tutorial/basic_tutorial_opn_torch.ipynb b/docs/basic_tutorial/basic_tutorial_opn_torch.ipynb index d199c10..9e4d1f2 100755 --- a/docs/basic_tutorial/basic_tutorial_opn_torch.ipynb +++ b/docs/basic_tutorial/basic_tutorial_opn_torch.ipynb @@ -177,60 +177,41 @@ "source": [ "#### Prepare for Training\n", "\n", - "Next we set up dataloaders for training with pytorch. The `LANfactory` uses custom dataloaders, taking into account particularities of the expected training data.\n", - "Specifically, we expect to receive a bunch of training data files (the present example generates only one), where each file hosts a large number of training examples. \n", - "So we want to define a dataloader which spits out batches from data with a specific training data file, and keeps checking when to load in a new file. \n", - "The way this is implemented here, is via the `DatasetTorch` class in `lanfactory.trainers`, which inherits from `torch.utils.data.Dataset` and prespecifies a `batch_size`. Finally this is supplied to a [`DataLoader`](https://pytorch.org/docs/stable/data.html), for which we keep the `batch_size` argument at 0.\n", + "Next we set up dataloaders for training with pytorch. The `LANfactory` provides convenient helper functions for this.\n", "\n", - "The `DatasetTorch` class is then called as an iterator via the DataLoader and takes care of batching as well as file loading internally. \n", + "The `make_train_valid_dataloaders` function handles:\n", + "- Splitting your data files into training and validation sets\n", + "- Creating the appropriate `DatasetTorch` objects\n", + "- Wrapping them in PyTorch `DataLoader` objects with sensible defaults\n", "\n", - "You may choose your own way of defining the `DataLoader` classes, downstream you are simply expected to supply one." + "Under the hood, this uses the `DatasetTorch` class which handles batching and file loading internally." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# MAKE DATALOADERS\n", "TRAINING_TYPE = \"opn\"\n", + "BATCH_SIZE = 2048\n", "\n", - "# List of datafiles (here only one)\n", + "# List of datafiles\n", "folder_ = Path(\"torch_nb_data/\") / \"opn_mlp\" / \"training_data\"\n", - "file_list_ = list(folder_.glob(\"*\"))\n", + "file_list_ = list(folder_.glob(\"*.pickle\"))\n", "\n", - "# Training dataset\n", - "torch_training_dataset = lanfactory.trainers.DatasetTorch(\n", + "# Create train and validation dataloaders with a single function call\n", + "torch_training_dataloader, torch_validation_dataloader, input_dim = lanfactory.trainers.make_train_valid_dataloaders(\n", " file_ids=file_list_,\n", - " batch_size=2048,\n", - " features_key=f\"{TRAINING_TYPE}_data\",\n", - " label_key=f\"{TRAINING_TYPE}_labels\",\n", + " batch_size=BATCH_SIZE,\n", + " network_type=\"opn\",\n", + " train_val_split=0.9,\n", ")\n", "\n", - "torch_training_dataloader = torch.utils.data.DataLoader(\n", - " torch_training_dataset,\n", - " shuffle=True,\n", - " batch_size=None,\n", - " num_workers=1,\n", - " pin_memory=True,\n", - ")\n", - "\n", - "# Validation dataset\n", - "torch_validation_dataset = lanfactory.trainers.DatasetTorch(\n", - " file_ids=file_list_,\n", - " batch_size=2048,\n", - " features_key=f\"{TRAINING_TYPE}_data\",\n", - " label_key=f\"{TRAINING_TYPE}_labels\",\n", - ")\n", - "\n", - "torch_validation_dataloader = torch.utils.data.DataLoader(\n", - " torch_validation_dataset,\n", - " shuffle=True,\n", - " batch_size=None,\n", - " num_workers=1,\n", - " pin_memory=True,\n", - ")" + "print(f\"Training batches: {len(torch_training_dataloader)}\")\n", + "print(f\"Validation batches: {len(torch_validation_dataloader)}\")\n", + "print(f\"Input dimension: {input_dim}\")" ] }, { @@ -287,7 +268,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -313,21 +294,27 @@ } ], "source": [ - "net = lanfactory.trainers.TorchMLP(\n", + "# LOAD NETWORK\n", + "# Option 1: Using the factory function (recommended)\n", + "net = lanfactory.trainers.TorchMLPFactory(\n", " network_config=deepcopy(network_config),\n", - " input_shape=torch_training_dataset.input_dim,\n", - " save_folder=Path(\"torch_nb_data/torch_models\") / TRAINING_TYPE / MODEL,\n", - " generative_model_id=MODEL,\n", - " network_type=TRAINING_TYPE,\n", + " input_dim=input_dim,\n", + " network_type=\"opn\",\n", ")\n", "\n", + "# Option 2: Direct instantiation (also works)\n", + "# net = lanfactory.trainers.TorchMLP(\n", + "# network_config=deepcopy(network_config),\n", + "# input_shape=input_dim,\n", + "# network_type=\"opn\",\n", + "# )\n", + "\n", "# SAVE CONFIGS\n", "lanfactory.utils.save_configs(\n", " model_id=MODEL + \"_torch_\",\n", " save_folder=Path(\"torch_nb_data/torch_models\") / TRAINING_TYPE / MODEL,\n", " network_config=network_config,\n", " train_config=train_config,\n", - " allow_abs_path_folder_generation=True,\n", ")" ] }, @@ -485,7 +472,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -504,11 +491,11 @@ "model_path = Path(\"torch_nb_data/torch_models\") / TRAINING_TYPE / MODEL\n", "network_file_path = next(model_path.glob(\"*state_dict*\"))\n", "\n", - "\n", + "# LoadTorchMLPInfer automatically calls eval() for inference mode\n", "network = lanfactory.trainers.LoadTorchMLPInfer(\n", " model_file_path=network_file_path,\n", " network_config=network_config,\n", - " input_dim=torch_training_dataset.input_dim,\n", + " input_dim=input_dim,\n", " network_type=TRAINING_TYPE,\n", ")" ] diff --git a/notebooks/test_notebooks/load_jax_lan_cpn.ipynb b/notebooks/test_notebooks/load_jax_lan_cpn.ipynb index 284c1f4..aa35d73 100644 --- a/notebooks/test_notebooks/load_jax_lan_cpn.ipynb +++ b/notebooks/test_notebooks/load_jax_lan_cpn.ipynb @@ -386,7 +386,7 @@ ], "source": [ "# LAN\n", - "jax_infer_lan = lanfactory.trainers.MLPJaxFactory(\n", + "jax_infer_lan = lanfactory.trainers.JaxMLPFactory(\n", " network_config=network_config_lan,\n", " train=False,\n", ")\n", @@ -399,7 +399,7 @@ ")\n", "\n", "# # CPN\n", - "jax_infer_cpn = lanfactory.trainers.MLPJaxFactory(\n", + "jax_infer_cpn = lanfactory.trainers.JaxMLPFactory(\n", " network_config=network_config_cpn,\n", " train=False,\n", ")\n", diff --git a/notebooks/test_notebooks/test_jax_network.ipynb b/notebooks/test_notebooks/test_jax_network.ipynb index ecc5108..6557a06 100755 --- a/notebooks/test_notebooks/test_jax_network.ipynb +++ b/notebooks/test_notebooks/test_jax_network.ipynb @@ -324,7 +324,7 @@ "outputs": [], "source": [ "# LOAD NETWORK\n", - "jax_net = lanfactory.trainers.MLPJaxFactory(network_config=network_config, train=True)\n", + "jax_net = lanfactory.trainers.JaxMLPFactory(network_config=network_config, train=True)\n", "pickle.dump(\n", " network_config,\n", " open(\"../data/jax_models/\" + MODEL + \"/jax_network_config.pickle\", \"wb\"),\n", @@ -1070,7 +1070,7 @@ "outputs": [], "source": [ "# Loaded Net\n", - "jax_infer = lanfactory.trainers.MLPJaxFactory(\n", + "jax_infer = lanfactory.trainers.JaxMLPFactory(\n", " # network_config=\"../data/jax_models/\"\n", " # + MODEL\n", " # + \"/\"\n", @@ -1096,7 +1096,7 @@ "Cell \u001b[0;32mIn[22], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m forward_pass \u001b[38;5;241m=\u001b[39m \u001b[43mjax_infer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake_forward_partial\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/lanfactory/lib/python3.11/site-packages/flax/linen/module.py:699\u001b[0m, in \u001b[0;36mwrap_method_once..wrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 697\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m args \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(args[\u001b[38;5;241m0\u001b[39m], Module):\n\u001b[1;32m 698\u001b[0m \u001b[38;5;28mself\u001b[39m, args \u001b[38;5;241m=\u001b[39m args[\u001b[38;5;241m0\u001b[39m], args[\u001b[38;5;241m1\u001b[39m:]\n\u001b[0;32m--> 699\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_wrapped_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 700\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 701\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fun(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", "File \u001b[0;32m~/miniconda3/envs/lanfactory/lib/python3.11/site-packages/flax/linen/module.py:1216\u001b[0m, in \u001b[0;36mModule._call_wrapped_method\u001b[0;34m(self, fun, args, kwargs)\u001b[0m\n\u001b[1;32m 1214\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _use_named_call:\n\u001b[1;32m 1215\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m jax\u001b[38;5;241m.\u001b[39mnamed_scope(_derive_profiling_name(\u001b[38;5;28mself\u001b[39m, fun)):\n\u001b[0;32m-> 1216\u001b[0m y \u001b[38;5;241m=\u001b[39m \u001b[43mrun_fun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1217\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1218\u001b[0m y \u001b[38;5;241m=\u001b[39m run_fun(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", - "File \u001b[0;32m~/Library/CloudStorage/OneDrive-Personal/project_lanfactory/LANfactory/lanfactory/trainers/jax_mlp.py:221\u001b[0m, in \u001b[0;36mMLPJax.make_forward_partial\u001b[0;34m(self, seed, input_dim, state, add_jitted)\u001b[0m\n\u001b[1;32m 219\u001b[0m loaded_state \u001b[38;5;241m=\u001b[39m state\n\u001b[1;32m 220\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 221\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstate argument has to be a dictionary or a string!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 223\u001b[0m \u001b[38;5;66;03m# Make forward pass\u001b[39;00m\n\u001b[1;32m 224\u001b[0m net_forward \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mapply, loaded_state)\n", + "File \u001b[0;32m~/Library/CloudStorage/OneDrive-Personal/project_lanfactory/LANfactory/lanfactory/trainers/jax_mlp.py:221\u001b[0m, in \u001b[0;36mJaxMLP.make_forward_partial\u001b[0;34m(self, seed, input_dim, state, add_jitted)\u001b[0m\n\u001b[1;32m 219\u001b[0m loaded_state \u001b[38;5;241m=\u001b[39m state\n\u001b[1;32m 220\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 221\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstate argument has to be a dictionary or a string!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 223\u001b[0m \u001b[38;5;66;03m# Make forward pass\u001b[39;00m\n\u001b[1;32m 224\u001b[0m net_forward \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mapply, loaded_state)\n", "\u001b[0;31mValueError\u001b[0m: state argument has to be a dictionary or a string!" ] } diff --git a/notebooks/test_notebooks/test_jax_network_cpn.ipynb b/notebooks/test_notebooks/test_jax_network_cpn.ipynb index daa01a2..13756d8 100644 --- a/notebooks/test_notebooks/test_jax_network_cpn.ipynb +++ b/notebooks/test_notebooks/test_jax_network_cpn.ipynb @@ -417,7 +417,7 @@ "outputs": [], "source": [ "# LOAD NETWORK\n", - "jax_net = lanfactory.trainers.MLPJaxFactory(network_config=network_config, train=True)\n", + "jax_net = lanfactory.trainers.JaxMLPFactory(network_config=network_config, train=True)\n", "pickle.dump(\n", " network_config,\n", " open(\n", @@ -563,7 +563,7 @@ "outputs": [], "source": [ "# Loaded Net\n", - "jax_infer = lanfactory.trainers.MLPJaxFactory(\n", + "jax_infer = lanfactory.trainers.JaxMLPFactory(\n", " network_config=network_config,\n", " train=False,\n", ")" diff --git a/pyproject.toml b/pyproject.toml index ab58272..5709a81 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,8 @@ omit = [ "*/tests/*", "*/__init__.py", "*/conftest.py", + "*/cli/jax_train.py", # CLI entry points are tested via smoke tests + "*/cli/torch_train.py", # CLI entry points are tested via smoke tests ] [tool.coverage.report] diff --git a/src/lanfactory/cli/jax_train.py b/src/lanfactory/cli/jax_train.py index 97e1b6b..ffde030 100644 --- a/src/lanfactory/cli/jax_train.py +++ b/src/lanfactory/cli/jax_train.py @@ -472,7 +472,7 @@ def main( ) # Load network - net = lanfactory.trainers.MLPJaxFactory( + net = lanfactory.trainers.JaxMLPFactory( network_config=deepcopy(network_config), train=True, ) diff --git a/src/lanfactory/config/network_configs.py b/src/lanfactory/config/network_configs.py index 1a39144..ab4dc07 100755 --- a/src/lanfactory/config/network_configs.py +++ b/src/lanfactory/config/network_configs.py @@ -1,27 +1,42 @@ """This Module defines simple examples for network and training configurations that serve as inputs to the training classes in the package. + +The configs are organized as follows: +- network_config_mlp / train_config_mlp: For LAN (likelihood approximation networks) +- network_config_choice_prob / train_config_choice_prob: For CPN/OPN (choice probability networks) + +For backward compatibility, network_config_cpn, network_config_opn, train_config_cpn, +and train_config_opn are provided as aliases to the choice_prob configs. """ -network_config_cpn = { +# --- Network Configurations --- + +# LAN (Likelihood Approximation Network) config +# Output type: logprob (log-probabilities) +network_config_mlp = { "layer_sizes": [100, 100, 1], "activations": ["tanh", "tanh", "linear"], - "train_output_type": "logits", + "train_output_type": "logprob", } -network_config_opn = { +# Choice Probability Network config (used for both CPN and OPN) +# Output type: logits (transformed to log-probabilities during inference) +network_config_choice_prob = { "layer_sizes": [100, 100, 1], "activations": ["tanh", "tanh", "linear"], "train_output_type": "logits", } -network_config_mlp = { - "layer_sizes": [100, 100, 1], - "activations": ["tanh", "tanh", "linear"], - "train_output_type": "logprob", -} +# Backward-compatible aliases +network_config_cpn = network_config_choice_prob +network_config_opn = network_config_choice_prob -train_config_cpn = { +# --- Training Configurations --- + +# LAN training config +# Loss: huber (for log-probability regression) +train_config_mlp = { "cpu_batch_size": 256, "gpu_batch_size": 512, "n_epochs": 5, @@ -30,11 +45,13 @@ "lr_scheduler": "reduce_on_plateau", "lr_scheduler_params": {}, "weight_decay": 0.0, - "loss": "bcelogit", + "loss": "huber", "save_history": True, } -train_config_opn = { +# Choice Probability Network training config (used for both CPN and OPN) +# Loss: bcelogit (binary cross-entropy with logits) +train_config_choice_prob = { "cpu_batch_size": 256, "gpu_batch_size": 512, "n_epochs": 5, @@ -47,15 +64,6 @@ "save_history": True, } -train_config_mlp = { - "cpu_batch_size": 256, - "gpu_batch_size": 512, - "n_epochs": 5, - "optimizer": "adam", - "learning_rate": 0.002, - "lr_scheduler": "reduce_on_plateau", - "lr_scheduler_params": {}, - "weight_decay": 0.0, - "loss": "huber", - "save_history": True, -} +# Backward-compatible aliases +train_config_cpn = train_config_choice_prob +train_config_opn = train_config_choice_prob diff --git a/src/lanfactory/onnx/transform_onnx.py b/src/lanfactory/onnx/transform_onnx.py index 00fa95f..43d87c8 100755 --- a/src/lanfactory/onnx/transform_onnx.py +++ b/src/lanfactory/onnx/transform_onnx.py @@ -36,7 +36,6 @@ def transform_to_onnx( mynet = TorchMLP( network_config=network_config_mlp, input_shape=input_shape, - generative_model_id=None, ) mynet.load_state_dict( diff --git a/src/lanfactory/trainers/__init__.py b/src/lanfactory/trainers/__init__.py index 806e33c..d2780b5 100755 --- a/src/lanfactory/trainers/__init__.py +++ b/src/lanfactory/trainers/__init__.py @@ -1,14 +1,28 @@ -from .torch_mlp import DatasetTorch, TorchMLP -from .torch_mlp import ModelTrainerTorchMLP, LoadTorchMLPInfer, LoadTorchMLP -from .jax_mlp import MLPJaxFactory, MLPJax, ModelTrainerJaxMLP +from .torch_mlp import ( + DatasetTorch, + TorchMLP, + TorchMLPFactory, + ModelTrainerTorchMLP, + LoadTorchMLP, + LoadTorchMLPInfer, + make_dataloader, + make_train_valid_dataloaders, +) +from .jax_mlp import JaxMLPFactory, JaxMLP, ModelTrainerJaxMLP __all__ = [ + # Dataset and DataLoader helpers "DatasetTorch", + "make_dataloader", + "make_train_valid_dataloaders", + # Torch MLP "TorchMLP", + "TorchMLPFactory", "ModelTrainerTorchMLP", - "LoadTorchMLPInfer", "LoadTorchMLP", - "MLPJaxFactory", - "MLPJax", + "LoadTorchMLPInfer", + # Jax MLP + "JaxMLPFactory", + "JaxMLP", "ModelTrainerJaxMLP", ] diff --git a/src/lanfactory/trainers/jax_mlp.py b/src/lanfactory/trainers/jax_mlp.py index e7f10d7..d2df40a 100755 --- a/src/lanfactory/trainers/jax_mlp.py +++ b/src/lanfactory/trainers/jax_mlp.py @@ -24,8 +24,8 @@ print("mlflow not available") -def MLPJaxFactory(network_config: dict | str = {}, train: bool = True) -> "MLPJax": - """Factory function to create a MLPJax object. +def JaxMLPFactory(network_config: dict | str = {}, train: bool = True) -> "JaxMLP": + """Factory function to create a JaxMLP object. Arguments --------- network_config (dict | str): @@ -34,7 +34,7 @@ def MLPJaxFactory(network_config: dict | str = {}, train: bool = True) -> "MLPJa Whether the model should be trained or not. Returns ------- - MLPJax class initialized with the correct network configuration. + JaxMLP class initialized with the correct network configuration. """ if isinstance(network_config, str): @@ -46,7 +46,7 @@ def MLPJaxFactory(network_config: dict | str = {}, train: bool = True) -> "MLPJa "network_config argument is not passed as either a dictionary or a string (path to a file)!" ) - return MLPJax( + return JaxMLP( layer_sizes=network_config_internal["layer_sizes"], activations=network_config_internal["activations"], train_output_type=network_config_internal["train_output_type"], @@ -54,7 +54,7 @@ def MLPJaxFactory(network_config: dict | str = {}, train: bool = True) -> "MLPJa ) -class MLPJax(nn.Module): +class JaxMLP(nn.Module): """JaxMLP class. Arguments --------- @@ -233,7 +233,7 @@ class ModelTrainerJaxMLP: def __init__( self, train_config: dict, - model: MLPJax, + model: JaxMLP, train_dl: Any, valid_dl: Any, allow_abs_path_folder_generation: bool = False, @@ -246,8 +246,8 @@ def __init__( --------- train_config (dict): Dictionary containing the training configuration. - model (MLPJax): - The MLPJax model to be trained. + model (JaxMLP): + The JaxMLP model to be trained. train_dl (torch.utils.data.DataLoader): The training data loader. valid_dl (torch.utils.data.DataLoader): diff --git a/src/lanfactory/trainers/torch_mlp.py b/src/lanfactory/trainers/torch_mlp.py index cb388bd..a128709 100755 --- a/src/lanfactory/trainers/torch_mlp.py +++ b/src/lanfactory/trainers/torch_mlp.py @@ -154,6 +154,198 @@ def __data_generation( return X, y +# --- Helper Functions for DataLoader Creation --- + + +def make_dataloader( + file_ids: list[str] | list[Path], + batch_size: int, + network_type: str = "lan", + label_lower_bound: float | None = None, + shuffle: bool = True, + num_workers: int = 1, + pin_memory: bool = True, +) -> DataLoader: + """Create a DataLoader for LAN/CPN/OPN training. + + This is a convenience function that creates a DatasetTorch and wraps it + in a PyTorch DataLoader with sensible defaults. + + Arguments + --------- + file_ids: List of paths to training data pickle files. + batch_size: Batch size for training. + network_type: Type of network ("lan", "cpn", or "opn"). + Determines the feature/label keys in the data files. + label_lower_bound: Lower bound for labels. If None and network_type + is "lan", defaults to log(1e-10). + shuffle: Whether to shuffle data (default: True). + num_workers: Number of worker processes for data loading (default: 1). + pin_memory: Whether to pin memory for faster GPU transfer (default: True). + + Returns + ------- + torch.utils.data.DataLoader configured for training. + + Example + ------- + >>> file_list = list(Path("data/lan_mlp/ddm").glob("*.pickle")) + >>> train_dl = make_dataloader( + ... file_ids=file_list, + ... batch_size=4096, + ... network_type="lan", + ... ) + """ + # Set sensible defaults based on network type + if label_lower_bound is None and network_type == "lan": + label_lower_bound = np.log(1e-10) + + dataset = DatasetTorch( + file_ids=file_ids, + batch_size=batch_size, + features_key=f"{network_type}_data", + label_key=f"{network_type}_labels", + label_lower_bound=label_lower_bound, + ) + + return DataLoader( + dataset, + shuffle=shuffle, + batch_size=None, + num_workers=num_workers, + pin_memory=pin_memory, + ) + + +def make_train_valid_dataloaders( + file_ids: list[str] | list[Path], + batch_size: int, + network_type: str = "lan", + train_val_split: float = 0.9, + shuffle_files: bool = True, + label_lower_bound: float | None = None, + num_workers: int = 1, + pin_memory: bool = True, +) -> tuple[DataLoader, DataLoader, int]: + """Create train and validation DataLoaders with automatic file splitting. + + This is a convenience function that splits the file list into train/validation + sets and creates DataLoaders for each. + + Arguments + --------- + file_ids: List of paths to training data pickle files. + batch_size: Batch size for training. + network_type: Type of network ("lan", "cpn", or "opn"). + train_val_split: Fraction of files to use for training (default: 0.9). + shuffle_files: Whether to shuffle files before splitting (default: True). + label_lower_bound: Lower bound for labels. If None and network_type + is "lan", defaults to log(1e-10). + num_workers: Number of worker processes for data loading (default: 1). + pin_memory: Whether to pin memory for faster GPU transfer (default: True). + + Returns + ------- + tuple of (train_dataloader, valid_dataloader, input_dim) + + Example + ------- + >>> file_list = list(Path("data/lan_mlp/ddm").glob("*.pickle")) + >>> train_dl, valid_dl, input_dim = make_train_valid_dataloaders( + ... file_ids=file_list, + ... batch_size=4096, + ... network_type="lan", + ... train_val_split=0.9, + ... ) + """ + import random + + file_list = [str(f) for f in file_ids] # Ensure strings for consistency + if shuffle_files: + random.shuffle(file_list) + + split_idx = int(len(file_list) * train_val_split) + train_files = file_list[:split_idx] + valid_files = file_list[split_idx:] + + if len(train_files) == 0: + raise ValueError( + f"No training files after split. Got {len(file_list)} files " + f"with train_val_split={train_val_split}" + ) + if len(valid_files) == 0: + raise ValueError( + f"No validation files after split. Got {len(file_list)} files " + f"with train_val_split={train_val_split}" + ) + + train_dl = make_dataloader( + file_ids=train_files, + batch_size=batch_size, + network_type=network_type, + label_lower_bound=label_lower_bound, + shuffle=True, + num_workers=num_workers, + pin_memory=pin_memory, + ) + + valid_dl = make_dataloader( + file_ids=valid_files, + batch_size=batch_size, + network_type=network_type, + label_lower_bound=label_lower_bound, + shuffle=True, + num_workers=num_workers, + pin_memory=pin_memory, + ) + + return train_dl, valid_dl, train_dl.dataset.input_dim + + +# --- Factory Functions --- + + +def TorchMLPFactory( + network_config: dict | str, + input_dim: int, + network_type: str | None = None, +) -> "TorchMLP": + """Factory function to create a TorchMLP object. + + This provides a consistent API with JaxMLPFactory and handles + loading network configs from pickle files. + + Arguments + --------- + network_config: Dictionary containing the network configuration, + or path to a pickled config file. + input_dim: Input dimension (typically from dataloader.dataset.input_dim). + network_type: Network type ("lan", "cpn", "opn"). If not provided, + will be inferred from train_output_type in network_config. + + Returns + ------- + TorchMLP instance ready for training. + + Example + ------- + >>> train_dl, valid_dl, input_dim = make_train_valid_dataloaders(...) + >>> net = TorchMLPFactory( + ... network_config=network_config, + ... input_dim=input_dim, + ... ) + """ + if isinstance(network_config, str): + with open(network_config, "rb") as f: + network_config = pickle.load(f) + + return TorchMLP( + network_config=network_config, + input_shape=input_dim, + network_type=network_type, + ) + + class TorchMLP(nn.Module): """TorchMLP class. @@ -167,15 +359,11 @@ class TorchMLP(nn.Module): Network type. """ - # AF-TODO: Potentially split this via super-class - # In the end I want 'eval', but differentiable - # w.r.t to input ...., might be a problem def __init__( self, network_config: dict, input_shape: int = 10, network_type: str | None = None, - **kwargs, ) -> None: super(TorchMLP, self).__init__() @@ -649,160 +837,144 @@ def _save_onnx(model: TorchMLP, dev: torch.device, path: str) -> None: logger.info(f"Saving model to ONNX format to: {path}") -class LoadTorchMLPInfer: - """Class to load TorchMLP models for inference. (This - was originally useful directly for application in the - HDDM toolbox). +class LoadTorchMLP: + """General-purpose class to load TorchMLP models. + + Does NOT call eval() by default - suitable for fine-tuning or further training. + For inference with eval() enabled, use LoadTorchMLPInfer instead. Arguments --------- model_file_path (str): - Path to the model file. - network_config (dict): - Network configuration. + Path to the model state dict file. + network_config (dict | str): + Network configuration dictionary or path to a pickled config file. input_dim (int): Input dimension. - + network_type (str | None): + Network type ("lan", "cpn", "opn"). If not provided, will be + inferred from train_output_type in network_config. + inference_mode (bool): + If True, sets network to eval mode. Default is False for general use. + Use LoadTorchMLPInfer for inference with eval() enabled by default. """ def __init__( self, - model_file_path: str | None = None, - network_config: dict | str | None = None, - input_dim: int | None = None, + model_file_path: str, + network_config: dict | str, + input_dim: int, network_type: str | None = None, + inference_mode: bool = False, ) -> None: - if input_dim is None: - raise ValueError("input_dim is required") - if model_file_path is None: - raise ValueError("model_file_path is required") - - self.model_file_path = model_file_path torch.backends.cudnn.benchmark = True self.dev = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) + self.model_file_path = model_file_path + # Load network config from pickle file if string path provided if isinstance(network_config, str): with open(network_config, "rb") as f: self.network_config = pickle.load(f) - elif isinstance(network_config, dict): - self.network_config = network_config else: - raise ValueError("network config is neither a string nor a dictionary") + self.network_config = network_config self.input_dim = input_dim + # Create model self.net = TorchMLP( network_config=self.network_config, input_shape=self.input_dim, - generative_model_id=None, network_type=network_type, ) + + # Load state dict if not torch.cuda.is_available(): self.net.load_state_dict( torch.load(self.model_file_path, map_location=torch.device("cpu")) ) else: self.net.load_state_dict(torch.load(self.model_file_path)) + self.net.to(self.dev) - self.net.eval() + + # Set eval mode if requested + if inference_mode: + self.net.eval() @torch.no_grad() def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the network. + + Arguments + --------- + x: Input tensor. + + Returns + ------- + Output tensor. + """ return self.net(x.to(self.dev)) @torch.no_grad() def predict_on_batch(self, x: np.ndarray | None = None) -> np.ndarray: - """ - Intended as function that computes trial wise log-likelihoods - from a matrix input. - To be used primarily through the HDDM toolbox. + """Make predictions on a batch of data. + + This method is intended for computing trial-wise log-likelihoods + from a matrix input, and is commonly used through the HDDM toolbox. Arguments --------- - x (numpy.ndarray(dtype=numpy.float32)): - Matrix which will be passed through the network. - LANs expect the matrix columns to follow a specific order. - When used in HDDM, x will be passed as follows. - The first few columns are trial wise model parameters - (order specified in the model_config file under the 'params' key). - The last two columns are filled with trial wise - reaction times and choices. - When not used via HDDM, no such restriction applies. - Output - ------ - numpy.ndarray(dtype = numpy.float32): - Output of the network. When called through HDDM, - this is expected as trial-wise log likelihoods - of a given generative model. + x (numpy.ndarray): + Input matrix (dtype should be numpy.float32). + For LANs, columns should follow a specific order: + model parameters followed by reaction times and choices. + Returns + ------- + numpy.ndarray: + Network output as numpy array. """ return self.net(torch.from_numpy(x).to(self.dev)).cpu().numpy() -class LoadTorchMLP: - """Class to load TorchMLP models. +class LoadTorchMLPInfer(LoadTorchMLP): + """Model loader with inference mode enabled by default. + + Calls eval() on the network, suitable for inference/prediction. + For fine-tuning or further training, use LoadTorchMLP instead. + + This class was originally useful directly for application in the + HDDM toolbox. Arguments --------- model_file_path (str): - Path to the model file. - network_config (dict): - Network configuration. + Path to the model state dict file. + network_config (dict | str): + Network configuration dictionary or path to a pickled config file. input_dim (int): - Input dimension.""" + Input dimension. + network_type (str | None): + Network type ("lan", "cpn", "opn"). If not provided, will be + inferred from train_output_type in network_config. + inference_mode (bool): + If True, sets network to eval mode. Default is True for inference. + """ def __init__( self, model_file_path: str, network_config: dict | str, input_dim: int, + network_type: str | None = None, + inference_mode: bool = True, ) -> None: - self.dev = ( - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - ) - self.model_file_path = model_file_path - - # Load network config from pickle file if string path provided - if isinstance(network_config, str): - with open(network_config, "rb") as f: - self.network_config = pickle.load(f) - else: - self.network_config = network_config - - self.input_dim = input_dim - - self.net = TorchMLP( - network_config=self.network_config, - input_shape=self.input_dim, - generative_model_id=None, + super().__init__( + model_file_path=model_file_path, + network_config=network_config, + input_dim=input_dim, + network_type=network_type, + inference_mode=inference_mode, ) - if not torch.cuda.is_available(): - self.net.load_state_dict( - torch.load(self.model_file_path, map_location=torch.device("cpu")) - ) - else: # pragma: no cover - self.net.load_state_dict(torch.load(self.model_file_path)) - self.net.to(self.dev) - - @torch.no_grad() - def __call__(self, x: torch.Tensor) -> torch.Tensor: - return self.net(x.to(self.dev)) - - @torch.no_grad() - def predict_on_batch(self, x: np.ndarray | None = None) -> np.ndarray: - """Makes predictions on a batch of data. - - Args: - x: Input data as numpy array. - - Returns: - numpy.ndarray: Model predictions as numpy array. - - Note: - Input data is automatically converted to torch tensor and moved to the - appropriate device (CPU/GPU). Output is converted back to numpy array - on CPU. - """ - return self.net(torch.from_numpy(x).to(self.dev)).cpu().numpy() diff --git a/tests/test_end_to_end_jax.py b/tests/test_end_to_end_jax.py index a412ee3..34b3a62 100644 --- a/tests/test_end_to_end_jax.py +++ b/tests/test_end_to_end_jax.py @@ -145,7 +145,7 @@ def test_end_to_end_lan_mlp( pin_memory=True, ) - jax_net = lanfactory.trainers.MLPJaxFactory( + jax_net = lanfactory.trainers.JaxMLPFactory( network_config=network_config, train=True ) @@ -167,7 +167,7 @@ def test_end_to_end_lan_mlp( verbose=1, ) - jax_infer = lanfactory.trainers.MLPJaxFactory( + jax_infer = lanfactory.trainers.JaxMLPFactory( network_config=network_config, train=False, ) diff --git a/tests/test_end_to_end_torch.py b/tests/test_end_to_end_torch.py index 6fb955f..cbd7ff6 100644 --- a/tests/test_end_to_end_torch.py +++ b/tests/test_end_to_end_torch.py @@ -149,7 +149,6 @@ def test_end_to_end_lan_mlp( network_config=network_config, input_shape=torch_training_dataset.input_dim, network_type=train_type, - train=True, ) logger.info(f"torch_net: {torch_net} \n") diff --git a/tests/test_jax_mlp.py b/tests/test_jax_mlp.py index eb9a863..895a1e6 100644 --- a/tests/test_jax_mlp.py +++ b/tests/test_jax_mlp.py @@ -6,26 +6,26 @@ import jax import jax.numpy as jnp -from lanfactory.trainers.jax_mlp import MLPJaxFactory, MLPJax +from lanfactory.trainers.jax_mlp import JaxMLPFactory, JaxMLP def test_mlp_jax_factory_with_dict(): - """Test MLPJaxFactory with dict network_config.""" + """Test JaxMLPFactory with dict network_config.""" network_config = { "layer_sizes": [100, 100, 1], "activations": ["tanh", "tanh", "linear"], "train_output_type": "logprob", } - model = MLPJaxFactory(network_config=network_config, train=True) + model = JaxMLPFactory(network_config=network_config, train=True) - assert isinstance(model, MLPJax) + assert isinstance(model, JaxMLP) assert model.layer_sizes == [100, 100, 1] assert model.activations == ["tanh", "tanh", "linear"] def test_mlp_jax_factory_with_string_path(tmp_path): - """Test MLPJaxFactory with string path to network_config.""" + """Test JaxMLPFactory with string path to network_config.""" network_config = { "layer_sizes": [100, 100, 1], "activations": ["tanh", "tanh", "linear"], @@ -38,21 +38,21 @@ def test_mlp_jax_factory_with_string_path(tmp_path): pickle.dump(network_config, f) # Load using string path - model = MLPJaxFactory(network_config=str(config_file), train=True) + model = JaxMLPFactory(network_config=str(config_file), train=True) - assert isinstance(model, MLPJax) + assert isinstance(model, JaxMLP) assert model.layer_sizes == [100, 100, 1] def test_mlp_jax_factory_raises_value_error(): - """Test MLPJaxFactory raises ValueError for invalid network_config type.""" + """Test JaxMLPFactory raises ValueError for invalid network_config type.""" with pytest.raises(ValueError, match="network_config argument is not passed"): - MLPJaxFactory(network_config=123, train=True) # Invalid type + JaxMLPFactory(network_config=123, train=True) # Invalid type def test_mlp_jax_class_initialization(): - """Test MLPJax class initialization.""" - model = MLPJax( + """Test JaxMLP class initialization.""" + model = JaxMLP( layer_sizes=[100, 100, 1], activations=["tanh", "tanh", "linear"], train_output_type="logprob", @@ -66,8 +66,8 @@ def test_mlp_jax_class_initialization(): def test_mlp_jax_forward_pass(): - """Test MLPJax forward pass.""" - model = MLPJax( + """Test JaxMLP forward pass.""" + model = JaxMLP( layer_sizes=[10, 10, 1], activations=["tanh", "tanh", "linear"], train_output_type="logprob", @@ -87,8 +87,8 @@ def test_mlp_jax_forward_pass(): def test_mlp_jax_with_different_activations(): - """Test MLPJax with different activation functions.""" - model = MLPJax( + """Test JaxMLP with different activation functions.""" + model = JaxMLP( layer_sizes=[10, 10, 1], activations=["relu", "sigmoid", "linear"], train_output_type="logits", @@ -100,22 +100,22 @@ def test_mlp_jax_with_different_activations(): def test_mlp_jax_factory_train_false(): - """Test MLPJaxFactory with train=False.""" + """Test JaxMLPFactory with train=False.""" network_config = { "layer_sizes": [100, 100, 1], "activations": ["tanh", "tanh", "linear"], "train_output_type": "logprob", } - model = MLPJaxFactory(network_config=network_config, train=False) + model = JaxMLPFactory(network_config=network_config, train=False) - assert isinstance(model, MLPJax) + assert isinstance(model, JaxMLP) assert model.train is False def test_mlp_jax_forward_with_non_linear_output_activation(): - """Test MLPJax forward pass with non-linear output activation.""" - from lanfactory.trainers.jax_mlp import MLPJaxFactory + """Test JaxMLP forward pass with non-linear output activation.""" + from lanfactory.trainers.jax_mlp import JaxMLPFactory network_config = { "layer_sizes": [10, 10, 1], @@ -123,7 +123,7 @@ def test_mlp_jax_forward_with_non_linear_output_activation(): "train_output_type": "logprob", } - model = MLPJaxFactory(network_config=network_config, train=True) + model = JaxMLPFactory(network_config=network_config, train=True) key = jax.random.PRNGKey(0) x = jax.random.normal(key, (5, 5)) @@ -138,8 +138,8 @@ def test_mlp_jax_forward_with_non_linear_output_activation(): def test_mlp_jax_inference_mode_with_logits(): - """Test MLPJax forward pass in inference mode with logits output.""" - from lanfactory.trainers.jax_mlp import MLPJaxFactory + """Test JaxMLP forward pass in inference mode with logits output.""" + from lanfactory.trainers.jax_mlp import JaxMLPFactory network_config = { "layer_sizes": [10, 10, 1], @@ -147,7 +147,7 @@ def test_mlp_jax_inference_mode_with_logits(): "train_output_type": "logits", } - model = MLPJaxFactory(network_config=network_config, train=False) + model = JaxMLPFactory(network_config=network_config, train=False) key = jax.random.PRNGKey(0) x = jax.random.normal(key, (5, 5)) @@ -162,8 +162,8 @@ def test_mlp_jax_inference_mode_with_logits(): def test_mlp_jax_load_state_from_file_error(): - """Test MLPJax load_state_from_file raises error when file_path is None.""" - from lanfactory.trainers.jax_mlp import MLPJaxFactory + """Test JaxMLP load_state_from_file raises error when file_path is None.""" + from lanfactory.trainers.jax_mlp import JaxMLPFactory import pytest network_config = { @@ -172,15 +172,15 @@ def test_mlp_jax_load_state_from_file_error(): "train_output_type": "logprob", } - model = MLPJaxFactory(network_config=network_config, train=True) + model = JaxMLPFactory(network_config=network_config, train=True) with pytest.raises(ValueError, match="file_path argument needs to be specified"): model.load_state_from_file(seed=42, input_dim=5, file_path=None) def test_mlp_jax_load_state_from_file_without_input_dim(tmp_path): - """Test MLPJax load_state_from_file without providing input_dim.""" - from lanfactory.trainers.jax_mlp import MLPJaxFactory + """Test JaxMLP load_state_from_file without providing input_dim.""" + from lanfactory.trainers.jax_mlp import JaxMLPFactory import flax.serialization network_config = { @@ -189,7 +189,7 @@ def test_mlp_jax_load_state_from_file_without_input_dim(tmp_path): "train_output_type": "logprob", } - model = MLPJaxFactory(network_config=network_config, train=True) + model = JaxMLPFactory(network_config=network_config, train=True) key = jax.random.PRNGKey(0) x = jax.random.normal(key, (5, 5)) @@ -210,8 +210,8 @@ def test_mlp_jax_load_state_from_file_without_input_dim(tmp_path): def test_mlp_jax_make_forward_partial_with_dict_state(tmp_path): - """Test MLPJax make_forward_partial with dict state.""" - from lanfactory.trainers.jax_mlp import MLPJaxFactory + """Test JaxMLP make_forward_partial with dict state.""" + from lanfactory.trainers.jax_mlp import JaxMLPFactory network_config = { "layer_sizes": [10, 10, 1], @@ -219,7 +219,7 @@ def test_mlp_jax_make_forward_partial_with_dict_state(tmp_path): "train_output_type": "logprob", } - model = MLPJaxFactory(network_config=network_config, train=True) + model = JaxMLPFactory(network_config=network_config, train=True) key = jax.random.PRNGKey(0) x = jax.random.normal(key, (5, 5)) @@ -241,8 +241,8 @@ def test_mlp_jax_make_forward_partial_with_dict_state(tmp_path): def test_mlp_jax_make_forward_partial_without_jit(tmp_path): - """Test MLPJax make_forward_partial without JIT compilation.""" - from lanfactory.trainers.jax_mlp import MLPJaxFactory + """Test JaxMLP make_forward_partial without JIT compilation.""" + from lanfactory.trainers.jax_mlp import JaxMLPFactory import flax.serialization network_config = { @@ -251,7 +251,7 @@ def test_mlp_jax_make_forward_partial_without_jit(tmp_path): "train_output_type": "logprob", } - model = MLPJaxFactory(network_config=network_config, train=True) + model = JaxMLPFactory(network_config=network_config, train=True) key = jax.random.PRNGKey(0) x = jax.random.normal(key, (5, 5)) @@ -275,8 +275,8 @@ def test_mlp_jax_make_forward_partial_without_jit(tmp_path): def test_mlp_jax_make_forward_partial_invalid_state_type(): - """Test MLPJax make_forward_partial raises error with invalid state type.""" - from lanfactory.trainers.jax_mlp import MLPJaxFactory + """Test JaxMLP make_forward_partial raises error with invalid state type.""" + from lanfactory.trainers.jax_mlp import JaxMLPFactory import pytest network_config = { @@ -285,7 +285,7 @@ def test_mlp_jax_make_forward_partial_invalid_state_type(): "train_output_type": "logprob", } - model = MLPJaxFactory(network_config=network_config, train=True) + model = JaxMLPFactory(network_config=network_config, train=True) # Test with invalid state type (list instead of dict or string) with pytest.raises( @@ -297,7 +297,7 @@ def test_mlp_jax_make_forward_partial_invalid_state_type(): def test_mlp_jax_with_logits_inference(): - """Test MLPJax forward pass with logits output in inference mode.""" + """Test JaxMLP forward pass with logits output in inference mode.""" import jax.numpy as jnp network_config = { @@ -306,7 +306,7 @@ def test_mlp_jax_with_logits_inference(): "train_output_type": "logits", } - model = MLPJaxFactory(network_config=network_config, train=False) + model = JaxMLPFactory(network_config=network_config, train=False) # Create dummy input rng = jax.random.PRNGKey(0) @@ -323,7 +323,7 @@ def test_mlp_jax_with_logits_inference(): def test_mlp_jax_with_non_linear_output_activation(): - """Test MLPJax with non-linear output activation.""" + """Test JaxMLP with non-linear output activation.""" import jax.numpy as jnp network_config = { @@ -332,7 +332,7 @@ def test_mlp_jax_with_non_linear_output_activation(): "train_output_type": "logprob", } - model = MLPJaxFactory(network_config=network_config, train=True) + model = JaxMLPFactory(network_config=network_config, train=True) # Create dummy input rng = jax.random.PRNGKey(0) diff --git a/tests/test_mlflow_integration.py b/tests/test_mlflow_integration.py index 8812aa0..bedc7f3 100644 --- a/tests/test_mlflow_integration.py +++ b/tests/test_mlflow_integration.py @@ -307,7 +307,7 @@ def test_jax_trainer_mlflow_logging( ): """Test that JAX trainer logs to MLflow correctly.""" from torch.utils.data import DataLoader - from lanfactory.trainers.jax_mlp import MLPJaxFactory, ModelTrainerJaxMLP + from lanfactory.trainers.jax_mlp import JaxMLPFactory, ModelTrainerJaxMLP from lanfactory.trainers.torch_mlp import DatasetTorch tracking_uri = test_mlflow_dir["tracking_uri"] @@ -354,7 +354,7 @@ def test_jax_trainer_mlflow_logging( dataloader_val = DataLoader(train_dataset, shuffle=False, batch_size=None) # Create model and trainer - net = MLPJaxFactory(network_config=network_config, train=True) + net = JaxMLPFactory(network_config=network_config, train=True) trainer = ModelTrainerJaxMLP( train_config=train_config, diff --git a/tests/test_torch_mlp.py b/tests/test_torch_mlp.py index d230d74..18d93d9 100644 --- a/tests/test_torch_mlp.py +++ b/tests/test_torch_mlp.py @@ -412,9 +412,7 @@ def test_model_trainer_torch_mlp_init_with_dict(): } # Create mock model and dataloaders - mock_model = TorchMLP( - network_config=network_config, input_shape=6, generative_model_id=None - ) + mock_model = TorchMLP(network_config=network_config, input_shape=6) mock_train_dl = MagicMock() mock_valid_dl = MagicMock() @@ -453,9 +451,7 @@ def test_model_trainer_torch_mlp_init_with_path(tmp_path): pickle.dump(train_config, f) # Create mock model and dataloaders - mock_model = TorchMLP( - network_config=network_config, input_shape=6, generative_model_id=None - ) + mock_model = TorchMLP(network_config=network_config, input_shape=6) mock_train_dl = MagicMock() mock_valid_dl = MagicMock() @@ -479,9 +475,7 @@ def test_load_torch_mlp_infer_with_dict_config(tmp_path): } # Create a simple model and save its state dict - model = TorchMLP( - network_config=network_config, input_shape=6, generative_model_id=None - ) + model = TorchMLP(network_config=network_config, input_shape=6) model_file = tmp_path / "model.pt" torch.save(model.state_dict(), model_file) @@ -510,9 +504,7 @@ def test_load_torch_mlp_infer_with_string_config(tmp_path): pickle.dump(network_config, f) # Create a simple model and save its state dict - model = TorchMLP( - network_config=network_config, input_shape=6, generative_model_id=None - ) + model = TorchMLP(network_config=network_config, input_shape=6) model_file = tmp_path / "model.pt" torch.save(model.state_dict(), model_file) @@ -536,9 +528,7 @@ def test_load_torch_mlp_infer_call_method(tmp_path): } # Create a simple model and save its state dict - model = TorchMLP( - network_config=network_config, input_shape=6, generative_model_id=None - ) + model = TorchMLP(network_config=network_config, input_shape=6) model_file = tmp_path / "model.pt" torch.save(model.state_dict(), model_file) @@ -566,9 +556,7 @@ def test_load_torch_mlp_infer_predict_on_batch(tmp_path): } # Create a simple model and save its state dict - model = TorchMLP( - network_config=network_config, input_shape=6, generative_model_id=None - ) + model = TorchMLP(network_config=network_config, input_shape=6) model_file = tmp_path / "model.pt" torch.save(model.state_dict(), model_file) diff --git a/tests/test_transform_onnx.py b/tests/test_transform_onnx.py index cdb169d..ec663e2 100644 --- a/tests/test_transform_onnx.py +++ b/tests/test_transform_onnx.py @@ -98,7 +98,6 @@ def test_transform_to_onnx_loads_network_config(mock_network_config): call_kwargs = mock_torch_mlp.call_args[1] assert call_kwargs["network_config"] == mock_network_config assert call_kwargs["input_shape"] == 6 - assert call_kwargs["generative_model_id"] is None def test_transform_to_onnx_loads_state_dict(mock_network_config, mock_state_dict): From 4222800f0bf1335950c80e1a98a932c892f76055 Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 16 Mar 2026 21:38:28 -0400 Subject: [PATCH 02/11] next round of improvements --- docs/api/hf.md | 1 + docs/using_huggingface.md | 199 +++++++++++++++++++++ mkdocs.yml | 2 + pyproject.toml | 6 + src/lanfactory/cli/download_hf.py | 157 ++++++++++++++++ src/lanfactory/cli/upload_hf.py | 178 ++++++++++++++++++ src/lanfactory/hf/__init__.py | 25 +++ src/lanfactory/hf/download.py | 173 ++++++++++++++++++ src/lanfactory/hf/model_card.py | 288 ++++++++++++++++++++++++++++++ src/lanfactory/hf/upload.py | 233 ++++++++++++++++++++++++ tests/cli/test_hf_cli.py | 166 +++++++++++++++++ tests/hf/__init__.py | 1 + tests/hf/test_download.py | 166 +++++++++++++++++ tests/hf/test_model_card.py | 211 ++++++++++++++++++++++ tests/hf/test_upload.py | 186 +++++++++++++++++++ 15 files changed, 1992 insertions(+) create mode 100644 docs/api/hf.md create mode 100644 docs/using_huggingface.md create mode 100644 src/lanfactory/cli/download_hf.py create mode 100644 src/lanfactory/cli/upload_hf.py create mode 100644 src/lanfactory/hf/__init__.py create mode 100644 src/lanfactory/hf/download.py create mode 100644 src/lanfactory/hf/model_card.py create mode 100644 src/lanfactory/hf/upload.py create mode 100644 tests/cli/test_hf_cli.py create mode 100644 tests/hf/__init__.py create mode 100644 tests/hf/test_download.py create mode 100644 tests/hf/test_model_card.py create mode 100644 tests/hf/test_upload.py diff --git a/docs/api/hf.md b/docs/api/hf.md new file mode 100644 index 0000000..4aaaf1b --- /dev/null +++ b/docs/api/hf.md @@ -0,0 +1 @@ +:::lanfactory.hf diff --git a/docs/using_huggingface.md b/docs/using_huggingface.md new file mode 100644 index 0000000..85c56e5 --- /dev/null +++ b/docs/using_huggingface.md @@ -0,0 +1,199 @@ +# Using HuggingFace Hub + +LANfactory provides CLI commands for uploading trained models to and downloading models from HuggingFace Hub. + +## Installation + +HuggingFace support requires the optional `hf` dependencies: + +```bash +pip install lanfactory[hf] +``` + +Or install all optional dependencies: + +```bash +pip install lanfactory[all] +``` + +## Authentication + +Before uploading, authenticate with HuggingFace: + +```bash +# Option 1: Login interactively +huggingface-cli login + +# Option 2: Set environment variable +export HF_TOKEN="your_token_here" + +# Option 3: Pass token via CLI +upload-hf ... --token "your_token_here" +``` + +## Uploading Models + +### 1. Create a `model_card.yaml` file + +In your trained model folder, create a `model_card.yaml` file with model metadata: + +```yaml +# Required metadata (HuggingFace frontmatter) +tags: + - lan + - ssm + - ddm + - hssm +library_name: onnx +license: mit + +# Model information +title: "LAN Model for DDM" +description: "Likelihood Approximation Network trained on DDM (Drift Diffusion Model) simulations." + +# Optional: Network architecture (auto-extracted from config.pickle if not provided) +architecture: + layer_sizes: [100, 100, 1] + activations: [tanh, tanh, linear] + network_type: lan + +# Optional: Training details +training: + epochs: 20 + optimizer: adam + learning_rate: 0.001 + +# Usage example (shown in README) +usage_example: | + import hssm + model = hssm.HSSM(data=my_data, model="ddm", loglik_kind="approx_differentiable") +``` + +### 2. Upload using the CLI + +```bash +upload-hf \ + --model-folder ./networks/lan/ddm/ \ + --network-type lan \ + --model-name ddm \ + --commit-message "Initial upload" +``` + +This uploads to `franklab/HSSM` (default) at path `lan/ddm/`. + +### CLI Options + +| Option | Required | Description | +|--------|----------|-------------| +| `--model-folder` | Yes | Path to folder with trained model artifacts | +| `--network-type` | Yes | Network type: `lan`, `cpn`, or `opn` | +| `--model-name` | Yes | Model name (e.g., `ddm`, `angle`) | +| `--repo-id` | No | HuggingFace repo ID (default: `franklab/HSSM`) | +| `--commit-message` | No | Git commit message (default: "Upload model") | +| `--private` | No | Create a private repository | +| `--create-repo` | No | Create repository if it doesn't exist | +| `--include-patterns` | No | Comma-separated glob patterns to include | +| `--exclude-patterns` | No | Comma-separated glob patterns to exclude | +| `--revision` | No | Branch or tag name for versioning | +| `--token` | No | HuggingFace API token | +| `--dry-run` | No | Show what would be uploaded without uploading | + +### Dry Run + +To preview what will be uploaded without actually uploading: + +```bash +upload-hf \ + --model-folder ./networks/lan/ddm/ \ + --network-type lan \ + --model-name ddm \ + --dry-run +``` + +## Downloading Models + +### Download using the CLI + +```bash +download-hf \ + --network-type lan \ + --model-name ddm \ + --output-folder ./models/ddm/ +``` + +This downloads from `franklab/HSSM` at path `lan/ddm/`. + +### CLI Options + +| Option | Required | Description | +|--------|----------|-------------| +| `--network-type` | Yes | Network type: `lan`, `cpn`, or `opn` | +| `--model-name` | Yes | Model name (e.g., `ddm`, `angle`) | +| `--output-folder` | Yes | Local destination folder | +| `--repo-id` | No | HuggingFace repo ID (default: `franklab/HSSM`) | +| `--revision` | No | Branch, tag, or commit to download (default: main) | +| `--include-patterns` | No | Comma-separated glob patterns to include | +| `--exclude-patterns` | No | Comma-separated glob patterns to exclude | +| `--token` | No | HuggingFace API token (for private repos) | +| `--force` | No | Overwrite existing files | + +## Repository Structure + +Models are organized in the repository using the following structure: + +``` +franklab/HSSM/ +├── lan/ +│ ├── ddm/ +│ │ ├── model.onnx +│ │ ├── network_config.pickle +│ │ ├── train_config.pickle +│ │ └── README.md +│ ├── angle/ +│ │ └── ... +│ └── weibull/ +│ └── ... +├── cpn/ +│ └── ... +└── opn/ + └── ... +``` + +## Using Downloaded Models with HSSM + +After downloading a model, you can use it with HSSM: + +```python +import hssm + +# HSSM will look for models in the franklab/HSSM repository +model = hssm.HSSM( + data=my_data, + model="ddm", + loglik_kind="approx_differentiable" +) +``` + +## Programmatic Usage + +You can also use the upload/download functions directly in Python: + +```python +from pathlib import Path +from lanfactory.hf import upload_model, download_model + +# Upload +upload_model( + model_folder=Path("./networks/lan/ddm/"), + network_type="lan", + model_name="ddm", + commit_message="v1.0.0 release", +) + +# Download +download_model( + network_type="lan", + model_name="ddm", + output_folder=Path("./models/ddm/"), +) +``` diff --git a/mkdocs.yml b/mkdocs.yml index a1aed11..3e77ce1 100755 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -13,9 +13,11 @@ nav: - OPN (PyTorch): basic_tutorial/basic_tutorial_opn_torch.ipynb - Guides: - MLflow Integration: using_mlflow.md + - HuggingFace Hub: using_huggingface.md - API: - lanfactory: api/lanfactory.md - config: api/config.md + - hf: api/hf.md - onnx: api/onnx.md - trainers: api/trainers.md - utils: api/utils.md diff --git a/pyproject.toml b/pyproject.toml index 5709a81..3f353f5 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,8 @@ keywords = [ [project.optional-dependencies] mlflow = ["mlflow>=3.6.0"] +hf = ["huggingface-hub>=0.20.0"] +all = ["mlflow>=3.6.0", "huggingface-hub>=0.20.0"] [dependency-groups] dev = [ @@ -109,6 +111,8 @@ omit = [ "*/conftest.py", "*/cli/jax_train.py", # CLI entry points are tested via smoke tests "*/cli/torch_train.py", # CLI entry points are tested via smoke tests + "*/cli/upload_hf.py", # CLI entry points are tested via smoke tests + "*/cli/download_hf.py", # CLI entry points are tested via smoke tests ] [tool.coverage.report] @@ -133,6 +137,8 @@ line-length = 120 jaxtrain = "lanfactory.cli.jax_train:app" torchtrain = "lanfactory.cli.torch_train:app" transform-onnx = "lanfactory.onnx.transform_onnx:app" +upload-hf = "lanfactory.cli.upload_hf:app" +download-hf = "lanfactory.cli.download_hf:app" [tool.setuptools.package-data] "lanfactory.cli" = ["config_network_training_lan.yaml"] diff --git a/src/lanfactory/cli/download_hf.py b/src/lanfactory/cli/download_hf.py new file mode 100644 index 0000000..1540f1a --- /dev/null +++ b/src/lanfactory/cli/download_hf.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python +"""Command-line interface for downloading models from HuggingFace Hub. + +This module provides a CLI tool for downloading LANfactory models +from HuggingFace Hub. + +Usage: + download-hf --network-type lan --model-name ddm --output-folder ./models/ddm/ +""" + +import logging +from pathlib import Path + +import typer + +app = typer.Typer() + +# Default repository for official HSSM models +DEFAULT_REPO_ID = "franklab/HSSM" + + +@app.command() +def main( + network_type: str = typer.Option( + ..., + "--network-type", + help="Network type: lan, cpn, or opn.", + ), + model_name: str = typer.Option( + ..., + "--model-name", + help="Model name (e.g., ddm, angle).", + ), + output_folder: Path = typer.Option( + ..., + "--output-folder", + help="Local destination folder.", + file_okay=False, + dir_okay=True, + resolve_path=True, + ), + repo_id: str = typer.Option( + DEFAULT_REPO_ID, + "--repo-id", + help=f"HuggingFace repository ID (default: {DEFAULT_REPO_ID}).", + ), + revision: str = typer.Option( + None, + "--revision", + help="Specific branch, tag, or commit to download (default: main).", + ), + include_patterns: str = typer.Option( + None, + "--include-patterns", + help="Comma-separated glob patterns for files to include.", + ), + exclude_patterns: str = typer.Option( + None, + "--exclude-patterns", + help="Comma-separated glob patterns for files to exclude.", + ), + token: str = typer.Option( + None, + "--token", + envvar="HF_TOKEN", + help="HuggingFace API token for private repos (defaults to HF_TOKEN env var).", + ), + force: bool = typer.Option( + False, + "--force", + help="Overwrite existing files.", + is_flag=True, + ), + log_level: str = typer.Option( + "WARNING", + "--log-level", + "-l", + help="Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).", + case_sensitive=False, + ), +): + """Download a LANfactory model from HuggingFace Hub. + + This command downloads model artifacts from a HuggingFace repository at the path + {network_type}/{model_name}/ (e.g., lan/ddm/). + + Example: + download-hf --network-type lan --model-name ddm --output-folder ./models/ddm/ + + This downloads from franklab/HSSM at path lan/ddm/ by default. + """ + # Set up logging + logging.basicConfig( + level=log_level.upper(), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logger = logging.getLogger(__name__) + + # Validate network_type + valid_network_types = ["lan", "cpn", "opn"] + if network_type not in valid_network_types: + raise typer.BadParameter( + f"network_type must be one of {valid_network_types}, got: {network_type}" + ) + + # Parse patterns + include_list = None + if include_patterns: + include_list = [p.strip() for p in include_patterns.split(",")] + + exclude_list = None + if exclude_patterns: + exclude_list = [p.strip() for p in exclude_patterns.split(",")] + + # Import here to provide better error message if huggingface_hub not installed + try: + from lanfactory.hf import download_model + except ImportError as e: + logger.error( + "huggingface_hub is required for HuggingFace downloads. " + "Install it with: pip install lanfactory[hf]" + ) + raise typer.Exit(code=1) from e + + # Show download source + path_in_repo = f"{network_type}/{model_name}" + typer.echo(f"Download source: {repo_id}/{path_in_repo}") + typer.echo(f"Output folder: {output_folder}") + + try: + result_path = download_model( + network_type=network_type, + model_name=model_name, + output_folder=output_folder, + repo_id=repo_id, + revision=revision, + include_patterns=include_list, + exclude_patterns=exclude_list, + token=token, + force=force, + ) + + typer.echo(f"\nModel downloaded to: {result_path}") + + except FileNotFoundError as e: + logger.error(str(e)) + raise typer.Exit(code=1) from e + except FileExistsError as e: + logger.error(str(e)) + raise typer.Exit(code=1) from e + except Exception as e: + logger.error("Download failed: %s", e) + raise typer.Exit(code=1) from e + + +if __name__ == "__main__": + app() diff --git a/src/lanfactory/cli/upload_hf.py b/src/lanfactory/cli/upload_hf.py new file mode 100644 index 0000000..b61f4d8 --- /dev/null +++ b/src/lanfactory/cli/upload_hf.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python +"""Command-line interface for uploading models to HuggingFace Hub. + +This module provides a CLI tool for uploading trained LANfactory models +to HuggingFace Hub with proper organization and metadata. + +Usage: + upload-hf --model-folder ./networks/lan/ddm/ --network-type lan --model-name ddm +""" + +import logging +from pathlib import Path + +import typer + +app = typer.Typer() + +# Default repository for official HSSM models +DEFAULT_REPO_ID = "franklab/HSSM" + + +@app.command() +def main( + model_folder: Path = typer.Option( + ..., + "--model-folder", + help="Path to the folder containing trained model artifacts (should contain model_card.yaml).", + exists=True, + file_okay=False, + dir_okay=True, + resolve_path=True, + ), + network_type: str = typer.Option( + ..., + "--network-type", + help="Network type: lan, cpn, or opn.", + ), + model_name: str = typer.Option( + ..., + "--model-name", + help="Model name (e.g., ddm, angle).", + ), + repo_id: str = typer.Option( + DEFAULT_REPO_ID, + "--repo-id", + help=f"HuggingFace repository ID (default: {DEFAULT_REPO_ID}).", + ), + commit_message: str = typer.Option( + "Upload model", + "--commit-message", + help="Git commit message for the upload.", + ), + private: bool = typer.Option( + False, + "--private", + help="Create a private repository.", + is_flag=True, + ), + create_repo: bool = typer.Option( + False, + "--create-repo", + help="Create the repository if it doesn't exist.", + is_flag=True, + ), + include_patterns: str = typer.Option( + None, + "--include-patterns", + help="Comma-separated glob patterns for files to include.", + ), + exclude_patterns: str = typer.Option( + None, + "--exclude-patterns", + help="Comma-separated glob patterns for files to exclude.", + ), + revision: str = typer.Option( + None, + "--revision", + help="Branch or tag name for versioning.", + ), + token: str = typer.Option( + None, + "--token", + envvar="HF_TOKEN", + help="HuggingFace API token (defaults to HF_TOKEN env var).", + ), + dry_run: bool = typer.Option( + False, + "--dry-run", + help="Show what would be uploaded without uploading.", + is_flag=True, + ), + log_level: str = typer.Option( + "WARNING", + "--log-level", + "-l", + help="Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).", + case_sensitive=False, + ), +): + """Upload a trained LANfactory model to HuggingFace Hub. + + This command uploads model artifacts to a HuggingFace repository at the path + {network_type}/{model_name}/ (e.g., lan/ddm/). + + The model folder must contain a model_card.yaml file with model metadata. + This YAML file is converted to a README.md for HuggingFace. + + Example: + upload-hf --model-folder ./networks/lan/ddm/ --network-type lan --model-name ddm + + This uploads to franklab/HSSM at path lan/ddm/ by default. + """ + # Set up logging + logging.basicConfig( + level=log_level.upper(), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logger = logging.getLogger(__name__) + + # Validate network_type + valid_network_types = ["lan", "cpn", "opn"] + if network_type not in valid_network_types: + raise typer.BadParameter( + f"network_type must be one of {valid_network_types}, got: {network_type}" + ) + + # Parse patterns + include_list = None + if include_patterns: + include_list = [p.strip() for p in include_patterns.split(",")] + + exclude_list = None + if exclude_patterns: + exclude_list = [p.strip() for p in exclude_patterns.split(",")] + + # Import here to provide better error message if huggingface_hub not installed + try: + from lanfactory.hf import upload_model + except ImportError as e: + logger.error( + "huggingface_hub is required for HuggingFace uploads. " + "Install it with: pip install lanfactory[hf]" + ) + raise typer.Exit(code=1) from e + + # Show upload destination + path_in_repo = f"{network_type}/{model_name}" + typer.echo(f"Upload destination: {repo_id}/{path_in_repo}") + + try: + url = upload_model( + model_folder=model_folder, + network_type=network_type, + model_name=model_name, + repo_id=repo_id, + commit_message=commit_message, + private=private, + create_repo=create_repo, + include_patterns=include_list, + exclude_patterns=exclude_list, + revision=revision, + token=token, + dry_run=dry_run, + ) + + if url and not dry_run: + typer.echo(f"\nView your model at: {url}") + + except FileNotFoundError as e: + logger.error(str(e)) + raise typer.Exit(code=1) from e + except Exception as e: + logger.error("Upload failed: %s", e) + raise typer.Exit(code=1) from e + + +if __name__ == "__main__": + app() diff --git a/src/lanfactory/hf/__init__.py b/src/lanfactory/hf/__init__.py new file mode 100644 index 0000000..ef4e76b --- /dev/null +++ b/src/lanfactory/hf/__init__.py @@ -0,0 +1,25 @@ +"""HuggingFace Hub integration for LANfactory. + +This module provides utilities for uploading trained models to and +downloading models from HuggingFace Hub. +""" + +from lanfactory.hf.model_card import ( + load_model_card_yaml, + generate_readme, + ModelCardConfig, +) +from lanfactory.hf.upload import upload_model +from lanfactory.hf.download import download_model + +# Default repository for official HSSM models +DEFAULT_REPO_ID = "franklab/HSSM" + +__all__ = [ + "DEFAULT_REPO_ID", + "load_model_card_yaml", + "generate_readme", + "ModelCardConfig", + "upload_model", + "download_model", +] diff --git a/src/lanfactory/hf/download.py b/src/lanfactory/hf/download.py new file mode 100644 index 0000000..d0c9151 --- /dev/null +++ b/src/lanfactory/hf/download.py @@ -0,0 +1,173 @@ +"""Download utilities for HuggingFace Hub. + +This module provides functions to download LANfactory models +from HuggingFace Hub. +""" + +import logging +import shutil +from pathlib import Path + +logger = logging.getLogger(__name__) + +# Default repository for official HSSM models +DEFAULT_REPO_ID = "franklab/HSSM" + + +def download_model( + network_type: str, + model_name: str, + output_folder: Path, + repo_id: str = DEFAULT_REPO_ID, + revision: str | None = None, + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, + token: str | None = None, + force: bool = False, +) -> Path: + """Download a model from HuggingFace Hub. + + Parameters + ---------- + network_type : str + Network type (e.g., "lan", "cpn", "opn"). + model_name : str + Model name (e.g., "ddm", "angle"). + output_folder : Path + Local destination folder. + repo_id : str + HuggingFace repository ID (default: "franklab/HSSM"). + revision : str | None + Specific branch/tag/commit to download (default: main). + include_patterns : list[str] | None + Glob patterns for files to include. + exclude_patterns : list[str] | None + Glob patterns for files to exclude. + token : str | None + HuggingFace API token for private repos. + force : bool + Whether to overwrite existing files. + + Returns + ------- + Path + Path to the downloaded model folder. + + Raises + ------ + ImportError + If huggingface_hub is not installed. + ValueError + If network_type is not valid. + FileExistsError + If output_folder exists and force is False. + """ + try: + from huggingface_hub import hf_hub_download, list_repo_files + except ImportError: + raise ImportError( + "huggingface_hub is required for HuggingFace downloads. " + "Install it with: pip install lanfactory[hf]" + ) + + # Validate inputs + valid_network_types = ["lan", "cpn", "opn"] + if network_type not in valid_network_types: + raise ValueError( + f"network_type must be one of {valid_network_types}, got: {network_type}" + ) + + output_folder = Path(output_folder) + + # Check if output folder exists + if output_folder.exists() and not force: + raise FileExistsError( + f"Output folder already exists: {output_folder}. Use --force to overwrite." + ) + + # Create output folder + output_folder.mkdir(parents=True, exist_ok=True) + + # Build the path prefix for this model + path_prefix = f"{network_type}/{model_name}/" + + # List files in the repository + try: + all_files = list_repo_files( + repo_id=repo_id, + revision=revision, + token=token, + ) + except Exception as e: + logger.error(f"Failed to list repository files: {e}") + raise + + # Filter files by path prefix + model_files = [f for f in all_files if f.startswith(path_prefix)] + + if not model_files: + raise FileNotFoundError( + f"No files found at {repo_id}/{path_prefix}. " + f"Available paths: {set(f.split('/')[0] for f in all_files if '/' in f)}" + ) + + # Apply include/exclude patterns + if include_patterns: + filtered_files = [] + for f in model_files: + filename = Path(f).name + for pattern in include_patterns: + if Path(filename).match(pattern): + filtered_files.append(f) + break + model_files = filtered_files + + if exclude_patterns: + filtered_files = [] + for f in model_files: + filename = Path(f).name + excluded = False + for pattern in exclude_patterns: + if Path(filename).match(pattern): + excluded = True + break + if not excluded: + filtered_files.append(f) + model_files = filtered_files + + if not model_files: + raise FileNotFoundError( + f"No files matching patterns found at {repo_id}/{path_prefix}" + ) + + logger.info(f"Downloading {len(model_files)} files from {repo_id}/{path_prefix}") + + # Download each file + downloaded_files = [] + for file_path in model_files: + filename = Path(file_path).name + logger.info(f" Downloading: {filename}") + + try: + local_path = hf_hub_download( + repo_id=repo_id, + filename=file_path, + revision=revision, + token=token, + ) + + # Copy to output folder + dest_path = output_folder / filename + shutil.copy2(local_path, dest_path) + downloaded_files.append(dest_path) + + except Exception as e: + logger.error(f"Failed to download {file_path}: {e}") + raise + + logger.info(f"Downloaded {len(downloaded_files)} files to {output_folder}") + print("\nDownload successful!") + print(f"Model saved to: {output_folder}") + print(f"Files downloaded: {len(downloaded_files)}") + + return output_folder diff --git a/src/lanfactory/hf/model_card.py b/src/lanfactory/hf/model_card.py new file mode 100644 index 0000000..55bdf3e --- /dev/null +++ b/src/lanfactory/hf/model_card.py @@ -0,0 +1,288 @@ +"""Model card utilities for HuggingFace Hub. + +This module reads user-provided model_card.yaml files and generates +HuggingFace-compatible README.md files with proper frontmatter. +""" + +import logging +import pickle +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import yaml + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelCardConfig: + """Configuration for model card generation. + + Attributes + ---------- + tags : list[str] + HuggingFace tags for discoverability. + library_name : str + Library name for HuggingFace (default: "onnx"). + license : str + License identifier (default: "mit"). + title : str + Model title. + description : str + Model description. + architecture : dict | None + Network architecture details. + training : dict | None + Training configuration details. + usage_example : str | None + Usage example code. + """ + + tags: list[str] = field(default_factory=lambda: ["lan", "ssm", "hssm"]) + library_name: str = "onnx" + license: str = "mit" + title: str = "LAN Model" + description: str = "Likelihood Approximation Network trained with LANfactory." + architecture: dict | None = None + training: dict | None = None + usage_example: str | None = None + + +def load_model_card_yaml(model_folder: Path) -> ModelCardConfig: + """Load model card configuration from YAML file. + + Parameters + ---------- + model_folder : Path + Path to the model folder containing model_card.yaml. + + Returns + ------- + ModelCardConfig + Parsed model card configuration. + + Raises + ------ + FileNotFoundError + If model_card.yaml is not found in the model folder. + """ + yaml_path = model_folder / "model_card.yaml" + + if not yaml_path.exists(): + raise FileNotFoundError( + f"model_card.yaml not found in {model_folder}. " + "Please create a model_card.yaml file with model metadata." + ) + + with open(yaml_path, "r") as f: + data = yaml.safe_load(f) + + # Extract fields with defaults + config = ModelCardConfig( + tags=data.get("tags", ["lan", "ssm", "hssm"]), + library_name=data.get("library_name", "onnx"), + license=data.get("license", "mit"), + title=data.get("title", "LAN Model"), + description=data.get( + "description", "Likelihood Approximation Network trained with LANfactory." + ), + architecture=data.get("architecture"), + training=data.get("training"), + usage_example=data.get("usage_example"), + ) + + # Try to fill in missing architecture/training from pickle configs + config = _fill_from_pickle_configs(model_folder, config) + + return config + + +def _fill_from_pickle_configs( + model_folder: Path, config: ModelCardConfig +) -> ModelCardConfig: + """Fill in missing config details from pickle files if available. + + Parameters + ---------- + model_folder : Path + Path to the model folder. + config : ModelCardConfig + Partially filled configuration. + + Returns + ------- + ModelCardConfig + Configuration with filled-in details from pickle files. + """ + # Try to find and load network config + if config.architecture is None: + network_config_files = list(model_folder.glob("*network_config.pickle")) + if network_config_files: + try: + with open(network_config_files[0], "rb") as f: + network_config = pickle.load(f) + config.architecture = { + "layer_sizes": network_config.get("layer_sizes"), + "activations": network_config.get("activations"), + "network_type": network_config.get("network_type"), + } + logger.info(f"Loaded architecture from {network_config_files[0]}") + except Exception as e: + logger.warning(f"Could not load network config: {e}") + + # Try to find and load train config + if config.training is None: + train_config_files = list(model_folder.glob("*train_config.pickle")) + if train_config_files: + try: + with open(train_config_files[0], "rb") as f: + train_config = pickle.load(f) + config.training = { + "epochs": train_config.get("n_epochs"), + "optimizer": train_config.get("optimizer"), + "learning_rate": train_config.get("learning_rate"), + "loss": train_config.get("loss"), + } + logger.info(f"Loaded training config from {train_config_files[0]}") + except Exception as e: + logger.warning(f"Could not load train config: {e}") + + return config + + +def generate_readme(config: ModelCardConfig, model_name: str | None = None) -> str: + """Generate HuggingFace-compatible README.md content. + + Parameters + ---------- + config : ModelCardConfig + Model card configuration. + model_name : str | None + Model name to include in usage example. + + Returns + ------- + str + README.md content with YAML frontmatter. + """ + # Build YAML frontmatter + frontmatter_dict: dict[str, Any] = { + "tags": config.tags, + "library_name": config.library_name, + "license": config.license, + } + + frontmatter = yaml.dump(frontmatter_dict, default_flow_style=False, sort_keys=False) + + # Build README content + lines = [ + "---", + frontmatter.strip(), + "---", + "", + f"# {config.title}", + "", + config.description, + "", + ] + + # Add architecture section if available + if config.architecture: + lines.extend( + [ + "## Architecture", + "", + ] + ) + if config.architecture.get("network_type"): + lines.append(f"- **Network Type:** {config.architecture['network_type']}") + if config.architecture.get("layer_sizes"): + lines.append(f"- **Layer Sizes:** {config.architecture['layer_sizes']}") + if config.architecture.get("activations"): + lines.append(f"- **Activations:** {config.architecture['activations']}") + lines.append("") + + # Add training section if available + if config.training: + lines.extend( + [ + "## Training", + "", + ] + ) + if config.training.get("epochs"): + lines.append(f"- **Epochs:** {config.training['epochs']}") + if config.training.get("optimizer"): + lines.append(f"- **Optimizer:** {config.training['optimizer']}") + if config.training.get("learning_rate"): + lines.append(f"- **Learning Rate:** {config.training['learning_rate']}") + if config.training.get("loss"): + lines.append(f"- **Loss:** {config.training['loss']}") + lines.append("") + + # Add usage example + lines.extend( + [ + "## Usage with HSSM", + "", + "```python", + ] + ) + + if config.usage_example: + lines.append(config.usage_example.strip()) + else: + # Default usage example + model_str = model_name or "ddm" + lines.extend( + [ + "import hssm", + f'model = hssm.HSSM(data=my_data, model="{model_str}", loglik_kind="approx_differentiable")', + ] + ) + + lines.extend( + [ + "```", + "", + "## Citation", + "", + "If you use this model, please cite:", + "", + "- [LANfactory](https://github.com/lnccbrown/LANfactory)", + "- [HSSM](https://github.com/lnccbrown/HSSM)", + "", + ] + ) + + return "\n".join(lines) + + +def write_readme( + model_folder: Path, config: ModelCardConfig, model_name: str | None = None +) -> Path: + """Generate and write README.md to model folder. + + Parameters + ---------- + model_folder : Path + Path to the model folder. + config : ModelCardConfig + Model card configuration. + model_name : str | None + Model name to include in usage example. + + Returns + ------- + Path + Path to the written README.md file. + """ + readme_content = generate_readme(config, model_name) + readme_path = model_folder / "README.md" + + with open(readme_path, "w") as f: + f.write(readme_content) + + logger.info(f"Generated README.md at {readme_path}") + return readme_path diff --git a/src/lanfactory/hf/upload.py b/src/lanfactory/hf/upload.py new file mode 100644 index 0000000..e51ee19 --- /dev/null +++ b/src/lanfactory/hf/upload.py @@ -0,0 +1,233 @@ +"""Upload utilities for HuggingFace Hub. + +This module provides functions to upload trained LANfactory models +to HuggingFace Hub with proper organization and metadata. +""" + +import logging +import shutil +import tempfile +from pathlib import Path + +logger = logging.getLogger(__name__) + +# Default repository for official HSSM models +DEFAULT_REPO_ID = "franklab/HSSM" + +# Default file patterns to include in uploads +DEFAULT_INCLUDE_PATTERNS = [ + "*.onnx", + "*.pt", + "*.jax", + "*_config.pickle", + "*.csv", + "model_card.yaml", +] + + +def upload_model( + model_folder: Path, + network_type: str, + model_name: str, + repo_id: str = DEFAULT_REPO_ID, + commit_message: str = "Upload model", + private: bool = False, + create_repo: bool = False, + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, + revision: str | None = None, + token: str | None = None, + dry_run: bool = False, +) -> str | None: + """Upload a trained model to HuggingFace Hub. + + Parameters + ---------- + model_folder : Path + Path to the folder containing trained model artifacts. + network_type : str + Network type (e.g., "lan", "cpn", "opn"). + model_name : str + Model name (e.g., "ddm", "angle"). + repo_id : str + HuggingFace repository ID (default: "franklab/HSSM"). + commit_message : str + Git commit message for the upload. + private : bool + Whether to create a private repository. + create_repo : bool + Whether to create the repository if it doesn't exist. + include_patterns : list[str] | None + Glob patterns for files to include. + exclude_patterns : list[str] | None + Glob patterns for files to exclude. + revision : str | None + Branch or tag name for versioning. + token : str | None + HuggingFace API token. + dry_run : bool + If True, show what would be uploaded without uploading. + + Returns + ------- + str | None + URL of the uploaded model, or None if dry_run is True. + + Raises + ------ + ImportError + If huggingface_hub is not installed. + FileNotFoundError + If model_folder doesn't exist or is missing required files. + ValueError + If network_type is not valid. + """ + try: + from huggingface_hub import HfApi, create_repo as hf_create_repo + except ImportError: + raise ImportError( + "huggingface_hub is required for HuggingFace uploads. " + "Install it with: pip install lanfactory[hf]" + ) + + # Validate inputs + model_folder = Path(model_folder) + if not model_folder.exists(): + raise FileNotFoundError(f"Model folder does not exist: {model_folder}") + + valid_network_types = ["lan", "cpn", "opn"] + if network_type not in valid_network_types: + raise ValueError( + f"network_type must be one of {valid_network_types}, got: {network_type}" + ) + + # Check for model_card.yaml + model_card_path = model_folder / "model_card.yaml" + if not model_card_path.exists(): + raise FileNotFoundError( + f"model_card.yaml not found in {model_folder}. " + "Please create a model_card.yaml file with model metadata." + ) + + # Use default patterns if not specified + if include_patterns is None: + include_patterns = DEFAULT_INCLUDE_PATTERNS + + # Collect files to upload + files_to_upload = _collect_files(model_folder, include_patterns, exclude_patterns) + + if not files_to_upload: + raise FileNotFoundError( + f"No files matching patterns {include_patterns} found in {model_folder}" + ) + + # Log what will be uploaded + path_in_repo = f"{network_type}/{model_name}" + logger.info(f"Upload destination: {repo_id}/{path_in_repo}") + logger.info(f"Files to upload ({len(files_to_upload)}):") + for f in files_to_upload: + logger.info(f" - {f.name}") + + if dry_run: + logger.info("DRY RUN: No files were uploaded.") + print( + f"\nDRY RUN: Would upload {len(files_to_upload)} files to {repo_id}/{path_in_repo}" + ) + for f in files_to_upload: + print(f" - {f.name}") + return None + + # Initialize API + api = HfApi(token=token) + + # Create repo if requested + if create_repo: + try: + hf_create_repo( + repo_id=repo_id, + repo_type="model", + private=private, + exist_ok=True, + token=token, + ) + logger.info(f"Repository created/verified: {repo_id}") + except Exception as e: + logger.error(f"Failed to create repository: {e}") + raise + + # Generate README.md from model_card.yaml + from lanfactory.hf.model_card import load_model_card_yaml, write_readme + + config = load_model_card_yaml(model_folder) + readme_path = write_readme(model_folder, config, model_name) + files_to_upload.append(readme_path) + + # Create a temporary directory with files organized for upload + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Copy files to temp directory + for file_path in files_to_upload: + dest = tmp_path / file_path.name + shutil.copy2(file_path, dest) + + # Upload the folder + try: + api.upload_folder( + folder_path=str(tmp_path), + repo_id=repo_id, + path_in_repo=path_in_repo, + commit_message=commit_message, + revision=revision, + token=token, + ) + except Exception as e: + logger.error(f"Upload failed: {e}") + raise + + # Construct URL + url = f"https://huggingface.co/{repo_id}/tree/main/{path_in_repo}" + logger.info(f"Upload successful: {url}") + print("\nUpload successful!") + print(f"View your model at: {url}") + + return url + + +def _collect_files( + folder: Path, + include_patterns: list[str], + exclude_patterns: list[str] | None, +) -> list[Path]: + """Collect files matching include patterns and not matching exclude patterns. + + Parameters + ---------- + folder : Path + Folder to search for files. + include_patterns : list[str] + Glob patterns for files to include. + exclude_patterns : list[str] | None + Glob patterns for files to exclude. + + Returns + ------- + list[Path] + List of file paths to upload. + """ + files = set() + + # Collect files matching include patterns + for pattern in include_patterns: + files.update(folder.glob(pattern)) + + # Remove files matching exclude patterns + if exclude_patterns: + for pattern in exclude_patterns: + excluded = set(folder.glob(pattern)) + files -= excluded + + # Filter to only regular files (not directories) + files = [f for f in files if f.is_file()] + + return sorted(files) diff --git a/tests/cli/test_hf_cli.py b/tests/cli/test_hf_cli.py new file mode 100644 index 0000000..eee4a64 --- /dev/null +++ b/tests/cli/test_hf_cli.py @@ -0,0 +1,166 @@ +"""CLI smoke tests for upload-hf and download-hf commands.""" + +import subprocess + +import yaml + + +class TestUploadHfCliHelp: + """Tests for upload-hf CLI help and argument validation.""" + + def test_help_command(self): + """Test that --help works.""" + result = subprocess.run( + ["upload-hf", "--help"], + capture_output=True, + text=True, + check=False, + ) + assert result.returncode == 0 + assert "Upload a trained LANfactory model" in result.stdout + assert "--model-folder" in result.stdout + assert "--network-type" in result.stdout + assert "--model-name" in result.stdout + + def test_missing_required_args(self): + """Test that missing required args causes error.""" + result = subprocess.run( + ["upload-hf"], + capture_output=True, + text=True, + check=False, + ) + assert result.returncode != 0 + + def test_invalid_network_type(self, tmp_path): + """Test that invalid network type causes error.""" + # Create a dummy model_card.yaml + yaml_path = tmp_path / "model_card.yaml" + with open(yaml_path, "w", encoding="utf-8") as f: + yaml.dump({"title": "Test"}, f) + + result = subprocess.run( + [ + "upload-hf", + "--model-folder", + str(tmp_path), + "--network-type", + "invalid", + "--model-name", + "ddm", + "--dry-run", + ], + capture_output=True, + text=True, + check=False, + ) + assert result.returncode != 0 + + +class TestDownloadHfCliHelp: + """Tests for download-hf CLI help and argument validation.""" + + def test_help_command(self): + """Test that --help works.""" + result = subprocess.run( + ["download-hf", "--help"], + capture_output=True, + text=True, + check=False, + ) + assert result.returncode == 0 + assert "Download a LANfactory model" in result.stdout + assert "--network-type" in result.stdout + assert "--model-name" in result.stdout + assert "--output-folder" in result.stdout + + def test_missing_required_args(self): + """Test that missing required args causes error.""" + result = subprocess.run( + ["download-hf"], + capture_output=True, + text=True, + check=False, + ) + assert result.returncode != 0 + + def test_invalid_network_type(self, tmp_path): + """Test that invalid network type causes error.""" + result = subprocess.run( + [ + "download-hf", + "--network-type", + "invalid", + "--model-name", + "ddm", + "--output-folder", + str(tmp_path / "output"), + ], + capture_output=True, + text=True, + check=False, + ) + assert result.returncode != 0 + + +class TestUploadHfDryRun: + """Tests for upload-hf dry run functionality.""" + + def test_dry_run_with_model_card(self, tmp_path): + """Test dry run with valid model_card.yaml.""" + # Create model_card.yaml + yaml_content = { + "tags": ["lan", "ssm", "ddm", "hssm"], + "library_name": "onnx", + "title": "Test Model", + "description": "Test description", + } + yaml_path = tmp_path / "model_card.yaml" + with open(yaml_path, "w", encoding="utf-8") as f: + yaml.dump(yaml_content, f) + + # Create a dummy model file + (tmp_path / "model.onnx").write_text("dummy content") + + result = subprocess.run( + [ + "upload-hf", + "--model-folder", + str(tmp_path), + "--network-type", + "lan", + "--model-name", + "test-model", + "--dry-run", + ], + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode == 0 + assert "DRY RUN" in result.stdout + + def test_dry_run_missing_model_card(self, tmp_path): + """Test dry run fails when model_card.yaml is missing.""" + result = subprocess.run( + [ + "upload-hf", + "--model-folder", + str(tmp_path), + "--network-type", + "lan", + "--model-name", + "test-model", + "--dry-run", + ], + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode != 0 + assert ( + "model_card.yaml not found" in result.stderr + or "model_card.yaml not found" in result.stdout + ) diff --git a/tests/hf/__init__.py b/tests/hf/__init__.py new file mode 100644 index 0000000..de11d0e --- /dev/null +++ b/tests/hf/__init__.py @@ -0,0 +1 @@ +"""Tests for lanfactory.hf module.""" diff --git a/tests/hf/test_download.py b/tests/hf/test_download.py new file mode 100644 index 0000000..02c71f9 --- /dev/null +++ b/tests/hf/test_download.py @@ -0,0 +1,166 @@ +"""Tests for download.py module.""" + +from unittest.mock import patch + +import pytest + +from lanfactory.hf.download import ( + DEFAULT_REPO_ID, + download_model, +) + + +class TestDownloadModel: + """Tests for download_model function.""" + + def test_raises_if_invalid_network_type(self, tmp_path): + """Test raises ValueError for invalid network_type.""" + with pytest.raises(ValueError, match="network_type must be one of"): + download_model( + network_type="invalid", + model_name="ddm", + output_folder=tmp_path / "output", + ) + + def test_raises_if_output_exists_without_force(self, tmp_path): + """Test raises FileExistsError if output folder exists and force=False.""" + output_folder = tmp_path / "output" + output_folder.mkdir() + + with pytest.raises(FileExistsError, match="already exists"): + download_model( + network_type="lan", + model_name="ddm", + output_folder=output_folder, + force=False, + ) + + @patch("huggingface_hub.list_repo_files") + @patch("huggingface_hub.hf_hub_download") + def test_downloads_files_to_output_folder( + self, mock_download, mock_list_files, tmp_path + ): + """Test files are downloaded to output folder.""" + output_folder = tmp_path / "output" + + # Mock list of files in repo + mock_list_files.return_value = [ + "lan/ddm/model.onnx", + "lan/ddm/config.pickle", + "cpn/angle/model.onnx", # Should not be downloaded + ] + + # Mock download - return a temp file path + temp_file = tmp_path / "temp_download.onnx" + temp_file.write_text("onnx content") + mock_download.return_value = str(temp_file) + + result = download_model( + network_type="lan", + model_name="ddm", + output_folder=output_folder, + ) + + assert result == output_folder + assert output_folder.exists() + + # Check that download was called for correct files + assert mock_download.call_count == 2 + + @patch("huggingface_hub.list_repo_files") + def test_raises_if_no_files_found(self, mock_list_files, tmp_path): + """Test raises FileNotFoundError if no files match.""" + output_folder = tmp_path / "output" + + # Mock empty file list for the path + mock_list_files.return_value = [ + "cpn/angle/model.onnx", # Wrong network type + ] + + with pytest.raises(FileNotFoundError, match="No files found"): + download_model( + network_type="lan", + model_name="ddm", + output_folder=output_folder, + ) + + @patch("huggingface_hub.list_repo_files") + @patch("huggingface_hub.hf_hub_download") + def test_applies_include_patterns(self, mock_download, mock_list_files, tmp_path): + """Test include patterns filter downloaded files.""" + output_folder = tmp_path / "output" + + mock_list_files.return_value = [ + "lan/ddm/model.onnx", + "lan/ddm/history.csv", + ] + + temp_file = tmp_path / "temp.onnx" + temp_file.write_text("content") + mock_download.return_value = str(temp_file) + + download_model( + network_type="lan", + model_name="ddm", + output_folder=output_folder, + include_patterns=["*.onnx"], + ) + + # Only .onnx file should be downloaded + assert mock_download.call_count == 1 + + @patch("huggingface_hub.list_repo_files") + @patch("huggingface_hub.hf_hub_download") + def test_applies_exclude_patterns(self, mock_download, mock_list_files, tmp_path): + """Test exclude patterns filter downloaded files.""" + output_folder = tmp_path / "output" + + mock_list_files.return_value = [ + "lan/ddm/model.onnx", + "lan/ddm/history.csv", + ] + + temp_file = tmp_path / "temp.onnx" + temp_file.write_text("content") + mock_download.return_value = str(temp_file) + + download_model( + network_type="lan", + model_name="ddm", + output_folder=output_folder, + exclude_patterns=["*.csv"], + ) + + # Only .onnx file should be downloaded + assert mock_download.call_count == 1 + + @patch("huggingface_hub.list_repo_files") + @patch("huggingface_hub.hf_hub_download") + def test_force_overwrites_existing(self, mock_download, mock_list_files, tmp_path): + """Test force=True allows overwriting existing folder.""" + output_folder = tmp_path / "output" + output_folder.mkdir() + + mock_list_files.return_value = ["lan/ddm/model.onnx"] + + temp_file = tmp_path / "temp.onnx" + temp_file.write_text("content") + mock_download.return_value = str(temp_file) + + # Should not raise + result = download_model( + network_type="lan", + model_name="ddm", + output_folder=output_folder, + force=True, + ) + + assert result == output_folder + + +class TestDefaults: + """Tests for default values.""" + + def test_default_repo_id(self): + """Test default repo ID is franklab/HSSM.""" + assert DEFAULT_REPO_ID == "franklab/HSSM" diff --git a/tests/hf/test_model_card.py b/tests/hf/test_model_card.py new file mode 100644 index 0000000..4680819 --- /dev/null +++ b/tests/hf/test_model_card.py @@ -0,0 +1,211 @@ +"""Tests for model_card.py module.""" + +import pickle + +import pytest +import yaml + +from lanfactory.hf.model_card import ( + ModelCardConfig, + generate_readme, + load_model_card_yaml, + write_readme, +) + + +class TestModelCardConfig: + """Tests for ModelCardConfig dataclass.""" + + def test_default_values(self): + """Test default values are set correctly.""" + config = ModelCardConfig() + assert config.tags == ["lan", "ssm", "hssm"] + assert config.library_name == "onnx" + assert config.license == "mit" + assert config.title == "LAN Model" + assert config.architecture is None + assert config.training is None + assert config.usage_example is None + + def test_custom_values(self): + """Test custom values are set correctly.""" + config = ModelCardConfig( + tags=["lan", "ssm", "ddm"], + title="DDM LAN Model", + description="Custom description", + architecture={"layer_sizes": [100, 100, 1]}, + ) + assert config.tags == ["lan", "ssm", "ddm"] + assert config.title == "DDM LAN Model" + assert config.description == "Custom description" + assert config.architecture == {"layer_sizes": [100, 100, 1]} + + +class TestLoadModelCardYaml: + """Tests for load_model_card_yaml function.""" + + def test_load_valid_yaml(self, tmp_path): + """Test loading a valid model_card.yaml file.""" + yaml_content = { + "tags": ["lan", "ssm", "ddm", "hssm"], + "library_name": "onnx", + "license": "mit", + "title": "LAN Model for DDM", + "description": "Test description", + } + + yaml_path = tmp_path / "model_card.yaml" + with open(yaml_path, "w") as f: + yaml.dump(yaml_content, f) + + config = load_model_card_yaml(tmp_path) + + assert config.tags == ["lan", "ssm", "ddm", "hssm"] + assert config.library_name == "onnx" + assert config.title == "LAN Model for DDM" + assert config.description == "Test description" + + def test_load_yaml_with_defaults(self, tmp_path): + """Test loading a minimal YAML file uses defaults.""" + yaml_content = {"title": "Minimal Model"} + + yaml_path = tmp_path / "model_card.yaml" + with open(yaml_path, "w") as f: + yaml.dump(yaml_content, f) + + config = load_model_card_yaml(tmp_path) + + assert config.title == "Minimal Model" + assert config.tags == ["lan", "ssm", "hssm"] # Default + assert config.library_name == "onnx" # Default + + def test_load_yaml_not_found(self, tmp_path): + """Test FileNotFoundError when YAML doesn't exist.""" + with pytest.raises(FileNotFoundError, match="model_card.yaml not found"): + load_model_card_yaml(tmp_path) + + def test_load_yaml_fills_from_pickle(self, tmp_path): + """Test that architecture is filled from pickle config.""" + # Create minimal YAML + yaml_content = {"title": "Test Model"} + yaml_path = tmp_path / "model_card.yaml" + with open(yaml_path, "w") as f: + yaml.dump(yaml_content, f) + + # Create network config pickle + network_config = { + "layer_sizes": [100, 100, 1], + "activations": ["tanh", "tanh", "linear"], + "network_type": "lan", + } + pickle_path = tmp_path / "test_network_config.pickle" + with open(pickle_path, "wb") as f: + pickle.dump(network_config, f) + + config = load_model_card_yaml(tmp_path) + + assert config.architecture is not None + assert config.architecture["layer_sizes"] == [100, 100, 1] + assert config.architecture["network_type"] == "lan" + + +class TestGenerateReadme: + """Tests for generate_readme function.""" + + def test_generates_valid_frontmatter(self): + """Test that generated README has valid YAML frontmatter.""" + config = ModelCardConfig( + tags=["lan", "ssm", "ddm"], + title="Test Model", + ) + + readme = generate_readme(config) + + # Check frontmatter markers + assert readme.startswith("---\n") + assert "\n---\n" in readme + + # Extract and parse frontmatter + parts = readme.split("---") + frontmatter = yaml.safe_load(parts[1]) + + assert frontmatter["tags"] == ["lan", "ssm", "ddm"] + assert frontmatter["library_name"] == "onnx" + assert frontmatter["license"] == "mit" + + def test_includes_title_and_description(self): + """Test that README includes title and description.""" + config = ModelCardConfig( + title="My Model", + description="My description", + ) + + readme = generate_readme(config) + + assert "# My Model" in readme + assert "My description" in readme + + def test_includes_architecture_section(self): + """Test that architecture section is included when provided.""" + config = ModelCardConfig( + architecture={ + "layer_sizes": [100, 100, 1], + "activations": ["tanh", "tanh", "linear"], + "network_type": "lan", + } + ) + + readme = generate_readme(config) + + assert "## Architecture" in readme + assert "**Network Type:** lan" in readme + assert "[100, 100, 1]" in readme + + def test_includes_training_section(self): + """Test that training section is included when provided.""" + config = ModelCardConfig( + training={ + "epochs": 20, + "optimizer": "adam", + "learning_rate": 0.001, + } + ) + + readme = generate_readme(config) + + assert "## Training" in readme + assert "**Epochs:** 20" in readme + assert "**Optimizer:** adam" in readme + + def test_includes_usage_example(self): + """Test that usage example section is included.""" + config = ModelCardConfig() + + readme = generate_readme(config, model_name="ddm") + + assert "## Usage with HSSM" in readme + assert 'model="ddm"' in readme + + def test_custom_usage_example(self): + """Test that custom usage example is used when provided.""" + config = ModelCardConfig(usage_example="custom_code_here()") + + readme = generate_readme(config) + + assert "custom_code_here()" in readme + + +class TestWriteReadme: + """Tests for write_readme function.""" + + def test_writes_readme_file(self, tmp_path): + """Test that README.md is written to disk.""" + config = ModelCardConfig(title="Test Model") + + readme_path = write_readme(tmp_path, config) + + assert readme_path.exists() + assert readme_path.name == "README.md" + + content = readme_path.read_text() + assert "# Test Model" in content diff --git a/tests/hf/test_upload.py b/tests/hf/test_upload.py new file mode 100644 index 0000000..dad5e94 --- /dev/null +++ b/tests/hf/test_upload.py @@ -0,0 +1,186 @@ +"""Tests for upload.py module.""" + +from unittest.mock import MagicMock, patch + +import pytest +import yaml + +from lanfactory.hf.upload import ( + DEFAULT_INCLUDE_PATTERNS, + DEFAULT_REPO_ID, + _collect_files, + upload_model, +) + + +class TestCollectFiles: + """Tests for _collect_files function.""" + + def test_collects_matching_files(self, tmp_path): + """Test collecting files matching patterns.""" + # Create test files + (tmp_path / "model.onnx").write_text("onnx content") + (tmp_path / "model.pt").write_text("pytorch content") + (tmp_path / "config.pickle").write_text("config content") + (tmp_path / "other.txt").write_text("other content") + + files = _collect_files( + tmp_path, + include_patterns=["*.onnx", "*.pt"], + exclude_patterns=None, + ) + + filenames = [f.name for f in files] + assert "model.onnx" in filenames + assert "model.pt" in filenames + assert "other.txt" not in filenames + + def test_excludes_files(self, tmp_path): + """Test excluding files matching patterns.""" + (tmp_path / "model.onnx").write_text("content") + (tmp_path / "backup.onnx").write_text("content") + + files = _collect_files( + tmp_path, + include_patterns=["*.onnx"], + exclude_patterns=["backup*"], + ) + + filenames = [f.name for f in files] + assert "model.onnx" in filenames + assert "backup.onnx" not in filenames + + def test_returns_empty_for_no_matches(self, tmp_path): + """Test returns empty list when no files match.""" + (tmp_path / "other.txt").write_text("content") + + files = _collect_files( + tmp_path, + include_patterns=["*.onnx"], + exclude_patterns=None, + ) + + assert files == [] + + +class TestUploadModel: + """Tests for upload_model function.""" + + def test_raises_if_folder_not_exists(self, tmp_path): + """Test raises FileNotFoundError if folder doesn't exist.""" + non_existent = tmp_path / "non_existent" + + with pytest.raises(FileNotFoundError, match="does not exist"): + upload_model( + model_folder=non_existent, + network_type="lan", + model_name="ddm", + ) + + def test_raises_if_invalid_network_type(self, tmp_path): + """Test raises ValueError for invalid network_type.""" + with pytest.raises(ValueError, match="network_type must be one of"): + upload_model( + model_folder=tmp_path, + network_type="invalid", + model_name="ddm", + ) + + def test_raises_if_model_card_missing(self, tmp_path): + """Test raises FileNotFoundError if model_card.yaml is missing.""" + with pytest.raises(FileNotFoundError, match="model_card.yaml not found"): + upload_model( + model_folder=tmp_path, + network_type="lan", + model_name="ddm", + ) + + def test_dry_run_does_not_upload(self, tmp_path): + """Test dry_run shows files but doesn't upload.""" + # Create model_card.yaml + yaml_content = { + "tags": ["lan", "ssm", "ddm"], + "title": "Test Model", + } + yaml_path = tmp_path / "model_card.yaml" + with open(yaml_path, "w") as f: + yaml.dump(yaml_content, f) + + # Create a model file + (tmp_path / "model.onnx").write_text("onnx content") + + result = upload_model( + model_folder=tmp_path, + network_type="lan", + model_name="ddm", + dry_run=True, + ) + + assert result is None + + @patch("huggingface_hub.HfApi") + @patch("huggingface_hub.create_repo") + def test_creates_repo_when_requested( + self, mock_create_repo, mock_api_class, tmp_path + ): + """Test repository is created when create_repo=True.""" + # Create model_card.yaml + yaml_content = {"tags": ["lan", "ssm"], "title": "Test"} + with open(tmp_path / "model_card.yaml", "w") as f: + yaml.dump(yaml_content, f) + (tmp_path / "model.onnx").write_text("content") + + # Mock API + mock_api = MagicMock() + mock_api_class.return_value = mock_api + + upload_model( + model_folder=tmp_path, + network_type="lan", + model_name="ddm", + create_repo=True, + token="fake_token", + ) + + mock_create_repo.assert_called_once() + + @patch("huggingface_hub.HfApi") + @patch("huggingface_hub.create_repo") + def test_uploads_to_correct_path(self, mock_create_repo, mock_api_class, tmp_path): + """Test files are uploaded to correct path in repo.""" + # Create model_card.yaml + yaml_content = {"tags": ["lan", "ssm"], "title": "Test"} + with open(tmp_path / "model_card.yaml", "w") as f: + yaml.dump(yaml_content, f) + (tmp_path / "model.onnx").write_text("content") + + # Mock API + mock_api = MagicMock() + mock_api_class.return_value = mock_api + + upload_model( + model_folder=tmp_path, + network_type="lan", + model_name="ddm", + repo_id="test/repo", + ) + + # Check upload_folder was called with correct path + mock_api.upload_folder.assert_called_once() + call_kwargs = mock_api.upload_folder.call_args[1] + assert call_kwargs["path_in_repo"] == "lan/ddm" + assert call_kwargs["repo_id"] == "test/repo" + + +class TestDefaults: + """Tests for default values.""" + + def test_default_repo_id(self): + """Test default repo ID is franklab/HSSM.""" + assert DEFAULT_REPO_ID == "franklab/HSSM" + + def test_default_include_patterns(self): + """Test default include patterns.""" + assert "*.onnx" in DEFAULT_INCLUDE_PATTERNS + assert "*.pt" in DEFAULT_INCLUDE_PATTERNS + assert "model_card.yaml" in DEFAULT_INCLUDE_PATTERNS From 9c5c24b75bb5e79e225aefa8558a5d6e20105bed Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 16 Mar 2026 23:03:47 -0400 Subject: [PATCH 03/11] fix rando issue with tests --- tests/cli/test_hf_cli.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/cli/test_hf_cli.py b/tests/cli/test_hf_cli.py index eee4a64..a1d8c11 100644 --- a/tests/cli/test_hf_cli.py +++ b/tests/cli/test_hf_cli.py @@ -1,9 +1,12 @@ """CLI smoke tests for upload-hf and download-hf commands.""" +import os import subprocess import yaml +_PLAIN_TEXT_ENV = {**os.environ, "NO_COLOR": "1", "COLUMNS": "200"} + class TestUploadHfCliHelp: """Tests for upload-hf CLI help and argument validation.""" @@ -15,6 +18,7 @@ def test_help_command(self): capture_output=True, text=True, check=False, + env=_PLAIN_TEXT_ENV, ) assert result.returncode == 0 assert "Upload a trained LANfactory model" in result.stdout @@ -67,6 +71,7 @@ def test_help_command(self): capture_output=True, text=True, check=False, + env=_PLAIN_TEXT_ENV, ) assert result.returncode == 0 assert "Download a LANfactory model" in result.stdout From 9384592df404a726a88f4128e0ad144e0d31b1db Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 16 Mar 2026 23:18:46 -0400 Subject: [PATCH 04/11] fix rando issue with tests 2 --- tests/cli/test_hf_cli.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/cli/test_hf_cli.py b/tests/cli/test_hf_cli.py index a1d8c11..d3b3954 100644 --- a/tests/cli/test_hf_cli.py +++ b/tests/cli/test_hf_cli.py @@ -1,11 +1,17 @@ """CLI smoke tests for upload-hf and download-hf commands.""" import os +import re import subprocess import yaml -_PLAIN_TEXT_ENV = {**os.environ, "NO_COLOR": "1", "COLUMNS": "200"} +_ANSI_RE = re.compile(r"\x1b\[[\d;]*m") +_WIDE_ENV = {**os.environ, "COLUMNS": "200"} + + +def _strip_ansi(text: str) -> str: + return _ANSI_RE.sub("", text) class TestUploadHfCliHelp: @@ -18,13 +24,14 @@ def test_help_command(self): capture_output=True, text=True, check=False, - env=_PLAIN_TEXT_ENV, + env=_WIDE_ENV, ) assert result.returncode == 0 - assert "Upload a trained LANfactory model" in result.stdout - assert "--model-folder" in result.stdout - assert "--network-type" in result.stdout - assert "--model-name" in result.stdout + stdout = _strip_ansi(result.stdout) + assert "Upload a trained LANfactory model" in stdout + assert "--model-folder" in stdout + assert "--network-type" in stdout + assert "--model-name" in stdout def test_missing_required_args(self): """Test that missing required args causes error.""" @@ -71,13 +78,14 @@ def test_help_command(self): capture_output=True, text=True, check=False, - env=_PLAIN_TEXT_ENV, + env=_WIDE_ENV, ) assert result.returncode == 0 - assert "Download a LANfactory model" in result.stdout - assert "--network-type" in result.stdout - assert "--model-name" in result.stdout - assert "--output-folder" in result.stdout + stdout = _strip_ansi(result.stdout) + assert "Download a LANfactory model" in stdout + assert "--network-type" in stdout + assert "--model-name" in stdout + assert "--output-folder" in stdout def test_missing_required_args(self): """Test that missing required args causes error.""" From c0eeab1b0fed5864b6ba136d77d544713b4a8227 Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 17 Mar 2026 00:05:23 -0400 Subject: [PATCH 05/11] add more robustifications --- src/lanfactory/cli/download_hf.py | 10 ++++------ src/lanfactory/cli/upload_hf.py | 10 ++++------ src/lanfactory/config/__init__.py | 4 ++++ src/lanfactory/hf/__init__.py | 13 +++++++------ src/lanfactory/hf/download.py | 10 ++++------ src/lanfactory/hf/upload.py | 26 ++++++++++++-------------- src/lanfactory/trainers/__init__.py | 19 +++++++++++++++++++ 7 files changed, 54 insertions(+), 38 deletions(-) diff --git a/src/lanfactory/cli/download_hf.py b/src/lanfactory/cli/download_hf.py index 1540f1a..262f776 100644 --- a/src/lanfactory/cli/download_hf.py +++ b/src/lanfactory/cli/download_hf.py @@ -13,10 +13,9 @@ import typer -app = typer.Typer() +from lanfactory.hf import DEFAULT_REPO_ID, VALID_NETWORK_TYPES -# Default repository for official HSSM models -DEFAULT_REPO_ID = "franklab/HSSM" +app = typer.Typer() @app.command() @@ -97,10 +96,9 @@ def main( logger = logging.getLogger(__name__) # Validate network_type - valid_network_types = ["lan", "cpn", "opn"] - if network_type not in valid_network_types: + if network_type not in VALID_NETWORK_TYPES: raise typer.BadParameter( - f"network_type must be one of {valid_network_types}, got: {network_type}" + f"network_type must be one of {list(VALID_NETWORK_TYPES)}, got: {network_type}" ) # Parse patterns diff --git a/src/lanfactory/cli/upload_hf.py b/src/lanfactory/cli/upload_hf.py index b61f4d8..538b1db 100644 --- a/src/lanfactory/cli/upload_hf.py +++ b/src/lanfactory/cli/upload_hf.py @@ -13,10 +13,9 @@ import typer -app = typer.Typer() +from lanfactory.hf import DEFAULT_REPO_ID, VALID_NETWORK_TYPES -# Default repository for official HSSM models -DEFAULT_REPO_ID = "franklab/HSSM" +app = typer.Typer() @app.command() @@ -118,10 +117,9 @@ def main( logger = logging.getLogger(__name__) # Validate network_type - valid_network_types = ["lan", "cpn", "opn"] - if network_type not in valid_network_types: + if network_type not in VALID_NETWORK_TYPES: raise typer.BadParameter( - f"network_type must be one of {valid_network_types}, got: {network_type}" + f"network_type must be one of {list(VALID_NETWORK_TYPES)}, got: {network_type}" ) # Parse patterns diff --git a/src/lanfactory/config/__init__.py b/src/lanfactory/config/__init__.py index 2d69f92..1878766 100755 --- a/src/lanfactory/config/__init__.py +++ b/src/lanfactory/config/__init__.py @@ -1,17 +1,21 @@ from .network_configs import ( network_config_mlp, + network_config_choice_prob, network_config_opn, network_config_cpn, train_config_mlp, + train_config_choice_prob, train_config_opn, train_config_cpn, ) __all__ = [ "network_config_mlp", + "network_config_choice_prob", "network_config_opn", "network_config_cpn", "train_config_mlp", + "train_config_choice_prob", "train_config_opn", "train_config_cpn", ] diff --git a/src/lanfactory/hf/__init__.py b/src/lanfactory/hf/__init__.py index ef4e76b..7dbda0a 100644 --- a/src/lanfactory/hf/__init__.py +++ b/src/lanfactory/hf/__init__.py @@ -4,19 +4,20 @@ downloading models from HuggingFace Hub. """ -from lanfactory.hf.model_card import ( +DEFAULT_REPO_ID = "franklab/HSSM" +VALID_NETWORK_TYPES = ("lan", "cpn", "opn") + +from lanfactory.hf.model_card import ( # noqa: E402 load_model_card_yaml, generate_readme, ModelCardConfig, ) -from lanfactory.hf.upload import upload_model -from lanfactory.hf.download import download_model - -# Default repository for official HSSM models -DEFAULT_REPO_ID = "franklab/HSSM" +from lanfactory.hf.upload import upload_model # noqa: E402 +from lanfactory.hf.download import download_model # noqa: E402 __all__ = [ "DEFAULT_REPO_ID", + "VALID_NETWORK_TYPES", "load_model_card_yaml", "generate_readme", "ModelCardConfig", diff --git a/src/lanfactory/hf/download.py b/src/lanfactory/hf/download.py index d0c9151..78266cf 100644 --- a/src/lanfactory/hf/download.py +++ b/src/lanfactory/hf/download.py @@ -8,10 +8,9 @@ import shutil from pathlib import Path -logger = logging.getLogger(__name__) +from lanfactory.hf import DEFAULT_REPO_ID, VALID_NETWORK_TYPES -# Default repository for official HSSM models -DEFAULT_REPO_ID = "franklab/HSSM" +logger = logging.getLogger(__name__) def download_model( @@ -71,10 +70,9 @@ def download_model( ) # Validate inputs - valid_network_types = ["lan", "cpn", "opn"] - if network_type not in valid_network_types: + if network_type not in VALID_NETWORK_TYPES: raise ValueError( - f"network_type must be one of {valid_network_types}, got: {network_type}" + f"network_type must be one of {list(VALID_NETWORK_TYPES)}, got: {network_type}" ) output_folder = Path(output_folder) diff --git a/src/lanfactory/hf/upload.py b/src/lanfactory/hf/upload.py index e51ee19..a8fd9ee 100644 --- a/src/lanfactory/hf/upload.py +++ b/src/lanfactory/hf/upload.py @@ -9,10 +9,9 @@ import tempfile from pathlib import Path -logger = logging.getLogger(__name__) +from lanfactory.hf import DEFAULT_REPO_ID, VALID_NETWORK_TYPES -# Default repository for official HSSM models -DEFAULT_REPO_ID = "franklab/HSSM" +logger = logging.getLogger(__name__) # Default file patterns to include in uploads DEFAULT_INCLUDE_PATTERNS = [ @@ -82,23 +81,14 @@ def upload_model( ValueError If network_type is not valid. """ - try: - from huggingface_hub import HfApi, create_repo as hf_create_repo - except ImportError: - raise ImportError( - "huggingface_hub is required for HuggingFace uploads. " - "Install it with: pip install lanfactory[hf]" - ) - # Validate inputs model_folder = Path(model_folder) if not model_folder.exists(): raise FileNotFoundError(f"Model folder does not exist: {model_folder}") - valid_network_types = ["lan", "cpn", "opn"] - if network_type not in valid_network_types: + if network_type not in VALID_NETWORK_TYPES: raise ValueError( - f"network_type must be one of {valid_network_types}, got: {network_type}" + f"network_type must be one of {list(VALID_NETWORK_TYPES)}, got: {network_type}" ) # Check for model_card.yaml @@ -137,6 +127,14 @@ def upload_model( print(f" - {f.name}") return None + try: + from huggingface_hub import HfApi, create_repo as hf_create_repo + except ImportError as exc: + raise ImportError( + "huggingface_hub is required for HuggingFace uploads. " + "Install it with: pip install lanfactory[hf]" + ) from exc + # Initialize API api = HfApi(token=token) diff --git a/src/lanfactory/trainers/__init__.py b/src/lanfactory/trainers/__init__.py index d2780b5..ecd495d 100755 --- a/src/lanfactory/trainers/__init__.py +++ b/src/lanfactory/trainers/__init__.py @@ -1,3 +1,5 @@ +import warnings + from .torch_mlp import ( DatasetTorch, TorchMLP, @@ -26,3 +28,20 @@ "JaxMLP", "ModelTrainerJaxMLP", ] + +_DEPRECATED_ALIASES = { + "MLPJax": "JaxMLP", + "MLPJaxFactory": "JaxMLPFactory", +} + + +def __getattr__(name: str): + if name in _DEPRECATED_ALIASES: + new_name = _DEPRECATED_ALIASES[name] + warnings.warn( + f"{name} is deprecated, use {new_name} instead", + DeprecationWarning, + stacklevel=2, + ) + return globals()[new_name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") From d32908349572da51ef0cdeb0fb1fb13e2514c78e Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 17 Mar 2026 00:11:37 -0400 Subject: [PATCH 06/11] more test fixes --- src/lanfactory/hf/download.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/lanfactory/hf/download.py b/src/lanfactory/hf/download.py index 78266cf..3bf8b5d 100644 --- a/src/lanfactory/hf/download.py +++ b/src/lanfactory/hf/download.py @@ -61,14 +61,6 @@ def download_model( FileExistsError If output_folder exists and force is False. """ - try: - from huggingface_hub import hf_hub_download, list_repo_files - except ImportError: - raise ImportError( - "huggingface_hub is required for HuggingFace downloads. " - "Install it with: pip install lanfactory[hf]" - ) - # Validate inputs if network_type not in VALID_NETWORK_TYPES: raise ValueError( @@ -83,6 +75,14 @@ def download_model( f"Output folder already exists: {output_folder}. Use --force to overwrite." ) + try: + from huggingface_hub import hf_hub_download, list_repo_files + except ImportError as exc: + raise ImportError( + "huggingface_hub is required for HuggingFace downloads. " + "Install it with: pip install lanfactory[hf]" + ) from exc + # Create output folder output_folder.mkdir(parents=True, exist_ok=True) From cc5769223a4782b275a3ea1698fb35f8b67d83b1 Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 17 Mar 2026 00:25:38 -0400 Subject: [PATCH 07/11] some more refinements --- tests/hf/test_download.py | 14 ++++++++++++++ tests/hf/test_upload.py | 11 +++++++++++ 2 files changed, 25 insertions(+) diff --git a/tests/hf/test_download.py b/tests/hf/test_download.py index 02c71f9..686c185 100644 --- a/tests/hf/test_download.py +++ b/tests/hf/test_download.py @@ -9,6 +9,15 @@ download_model, ) +try: + import huggingface_hub # noqa: F401 + + HAS_HF = True +except ImportError: + HAS_HF = False + +requires_hf = pytest.mark.skipif(not HAS_HF, reason="huggingface_hub not installed") + class TestDownloadModel: """Tests for download_model function.""" @@ -35,6 +44,7 @@ def test_raises_if_output_exists_without_force(self, tmp_path): force=False, ) + @requires_hf @patch("huggingface_hub.list_repo_files") @patch("huggingface_hub.hf_hub_download") def test_downloads_files_to_output_folder( @@ -67,6 +77,7 @@ def test_downloads_files_to_output_folder( # Check that download was called for correct files assert mock_download.call_count == 2 + @requires_hf @patch("huggingface_hub.list_repo_files") def test_raises_if_no_files_found(self, mock_list_files, tmp_path): """Test raises FileNotFoundError if no files match.""" @@ -84,6 +95,7 @@ def test_raises_if_no_files_found(self, mock_list_files, tmp_path): output_folder=output_folder, ) + @requires_hf @patch("huggingface_hub.list_repo_files") @patch("huggingface_hub.hf_hub_download") def test_applies_include_patterns(self, mock_download, mock_list_files, tmp_path): @@ -109,6 +121,7 @@ def test_applies_include_patterns(self, mock_download, mock_list_files, tmp_path # Only .onnx file should be downloaded assert mock_download.call_count == 1 + @requires_hf @patch("huggingface_hub.list_repo_files") @patch("huggingface_hub.hf_hub_download") def test_applies_exclude_patterns(self, mock_download, mock_list_files, tmp_path): @@ -134,6 +147,7 @@ def test_applies_exclude_patterns(self, mock_download, mock_list_files, tmp_path # Only .onnx file should be downloaded assert mock_download.call_count == 1 + @requires_hf @patch("huggingface_hub.list_repo_files") @patch("huggingface_hub.hf_hub_download") def test_force_overwrites_existing(self, mock_download, mock_list_files, tmp_path): diff --git a/tests/hf/test_upload.py b/tests/hf/test_upload.py index dad5e94..e454ebe 100644 --- a/tests/hf/test_upload.py +++ b/tests/hf/test_upload.py @@ -12,6 +12,15 @@ upload_model, ) +try: + import huggingface_hub # noqa: F401 + + HAS_HF = True +except ImportError: + HAS_HF = False + +requires_hf = pytest.mark.skipif(not HAS_HF, reason="huggingface_hub not installed") + class TestCollectFiles: """Tests for _collect_files function.""" @@ -118,6 +127,7 @@ def test_dry_run_does_not_upload(self, tmp_path): assert result is None + @requires_hf @patch("huggingface_hub.HfApi") @patch("huggingface_hub.create_repo") def test_creates_repo_when_requested( @@ -144,6 +154,7 @@ def test_creates_repo_when_requested( mock_create_repo.assert_called_once() + @requires_hf @patch("huggingface_hub.HfApi") @patch("huggingface_hub.create_repo") def test_uploads_to_correct_path(self, mock_create_repo, mock_api_class, tmp_path): From 134f57d4735b52093744fc0c5aec3ddf2c7162ed Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 18 Mar 2026 16:19:56 -0400 Subject: [PATCH 08/11] update ssms version --- pyproject.toml | 3 ++- tests/conftest.py | 44 ++++++++++++-------------------- tests/test_end_to_end_jax.py | 15 +++++------ tests/test_end_to_end_torch.py | 15 +++++------ tests/test_mlflow_integration.py | 40 ++++++++++++++++------------- 5 files changed, 52 insertions(+), 65 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3f353f5..7f4c33b 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ ] dependencies = [ - "ssm-simulators>=0.10.0", + "ssm-simulators>=0.12.2", "scipy>=1.15.2", "pandas>=2.2.3", "torch>=2.7.0", @@ -38,6 +38,7 @@ dependencies = [ "frozendict>=2.4.6", "onnx>=1.17.0", "matplotlib>=3.10.1", + "typer>=0.9.0", ] keywords = [ diff --git a/tests/conftest.py b/tests/conftest.py index 5b28fc7..1f1fcec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -113,26 +113,21 @@ def dummy_generator_config(model_selector): """Fixture providing a dummy model config for testing.""" def _dummy_generator_config(mode="random"): - # Initialize the generator config (for MLP LANs) simulator_param_mapping = True while simulator_param_mapping: - # TODO: #35 use this after ssms v1.0.0 release - # generator_config = ssms.config.get_default_generator_config("lan") - # and delete the line below - generator_config = deepcopy(ssms.config.data_generator_config["lan"]) - # Specify generative model (one from the list of included models mentioned above) + generator_config = ssms.config.get_default_generator_config("lan") generator_config["model"] = model_selector(mode=mode) - # Specify number of parameter sets to simulate - generator_config["n_parameter_sets"] = ( + generator_config["pipeline"]["n_parameter_sets"] = ( TEST_GENERATOR_CONSTANTS.N_PARAMETER_SETS ) - # Specify how many samples a simulation run should entail - generator_config["n_samples"] = TEST_GENERATOR_CONSTANTS.N_SAMPLES - # Specify folder in which to save generated data - generator_config["output_folder"] = os.path.join( + generator_config["pipeline"]["n_cpus"] = 1 + generator_config["simulator"]["n_samples"] = ( + TEST_GENERATOR_CONSTANTS.N_SAMPLES + ) + generator_config["output"]["folder"] = os.path.join( TEST_GENERATOR_CONSTANTS.OUT_FOLDER, str(uuid.uuid4()) ) - generator_config["n_training_samples_by_parameter_set"] = ( + generator_config["training"]["n_samples_per_param"] = ( TEST_GENERATOR_CONSTANTS.N_SAMPLES_BY_PARAMETER_SET ) model_config = deepcopy(ssms.config.model_config[generator_config["model"]]) @@ -153,30 +148,23 @@ def _dummy_generator_config(mode="random"): def dummy_generator_config_simple_two_choices(model_selector): """Fixture providing a dummy model config for testing.""" - # TODO: replace use of ssms.config.data_generator_config with ssms.config.get_default_generator_config - # after ssms v1.0.0 release def _dummy_generator_config_simple_two_choices(mode="random"): two_choices = False simulator_param_mapping = True while (not two_choices) or (simulator_param_mapping): - # Initialize the generator config (for MLP LANs) - # TODO: use this after ssms v1.0.0 release - # generator_config = ssms.config.get_default_generator_config("lan") - # and delete the line below - generator_config = deepcopy(ssms.config.data_generator_config["lan"]) - # Specify generative model (one from the list of included models mentioned above) + generator_config = ssms.config.get_default_generator_config("lan") generator_config["model"] = model_selector(mode=mode) - # Specify number of parameter sets to simulate - generator_config["n_parameter_sets"] = ( + generator_config["pipeline"]["n_parameter_sets"] = ( TEST_GENERATOR_CONSTANTS.N_PARAMETER_SETS ) - # Specify how many samples a simulation run should entail - generator_config["n_samples"] = TEST_GENERATOR_CONSTANTS.N_SAMPLES - # Specify folder in which to save generated data - generator_config["output_folder"] = os.path.join( + generator_config["pipeline"]["n_cpus"] = 1 + generator_config["simulator"]["n_samples"] = ( + TEST_GENERATOR_CONSTANTS.N_SAMPLES + ) + generator_config["output"]["folder"] = os.path.join( TEST_GENERATOR_CONSTANTS.OUT_FOLDER, str(uuid.uuid4()) ) - generator_config["n_training_samples_by_parameter_set"] = ( + generator_config["training"]["n_samples_per_param"] = ( TEST_GENERATOR_CONSTANTS.N_SAMPLES_BY_PARAMETER_SET ) model_config = deepcopy(ssms.config.model_config[generator_config["model"]]) diff --git a/tests/test_end_to_end_jax.py b/tests/test_end_to_end_jax.py index 34b3a62..9a5a65d 100644 --- a/tests/test_end_to_end_jax.py +++ b/tests/test_end_to_end_jax.py @@ -21,23 +21,20 @@ def dummy_training_data_files(generator_config, model_config, save=True): """Fixture providing a dummy training data for testing.""" - os.makedirs(generator_config["output_folder"], exist_ok=True) + output_folder = generator_config["output"]["folder"] + os.makedirs(output_folder, exist_ok=True) for i in range(TEST_GENERATOR_CONSTANTS.N_DATA_FILES): - # log progress logger.info( "Generating training data for file %d of %d", i + 1, TEST_GENERATOR_CONSTANTS.N_DATA_FILES, ) - my_dataset_generator = ssms.dataset_generators.lan_mlp.data_generator( - generator_config=generator_config, model_config=model_config + my_dataset_generator = ssms.dataset_generators.lan_mlp.TrainingDataGenerator( + config=generator_config, model_config=model_config ) - _ = my_dataset_generator.generate_data_training_uniform(save=save) + _ = my_dataset_generator.generate_data_training(save=save) - return [ - os.path.join(generator_config["output_folder"], file_) - for file_ in os.listdir(generator_config["output_folder"]) - ] + return [os.path.join(output_folder, file_) for file_ in os.listdir(output_folder)] @pytest.mark.flaky(reruns=2) diff --git a/tests/test_end_to_end_torch.py b/tests/test_end_to_end_torch.py index cbd7ff6..22a82ae 100644 --- a/tests/test_end_to_end_torch.py +++ b/tests/test_end_to_end_torch.py @@ -19,23 +19,20 @@ def dummy_training_data_files(generator_config, model_config, save=True): """Fixture providing a dummy training data for testing.""" - os.makedirs(generator_config["output_folder"], exist_ok=True) + output_folder = generator_config["output"]["folder"] + os.makedirs(output_folder, exist_ok=True) for i in range(TEST_GENERATOR_CONSTANTS.N_DATA_FILES): - # log progress logger.info( "Generating training data for file %d of %d", i + 1, TEST_GENERATOR_CONSTANTS.N_DATA_FILES, ) - my_dataset_generator = ssms.dataset_generators.lan_mlp.data_generator( - generator_config=generator_config, model_config=model_config + my_dataset_generator = ssms.dataset_generators.lan_mlp.TrainingDataGenerator( + config=generator_config, model_config=model_config ) - _ = my_dataset_generator.generate_data_training_uniform(save=save) + _ = my_dataset_generator.generate_data_training(save=save) - return [ - os.path.join(generator_config["output_folder"], file_) - for file_ in os.listdir(generator_config["output_folder"]) - ] + return [os.path.join(output_folder, file_) for file_ in os.listdir(output_folder)] @pytest.mark.flaky(reruns=2) diff --git a/tests/test_mlflow_integration.py b/tests/test_mlflow_integration.py index bedc7f3..c2f348a 100644 --- a/tests/test_mlflow_integration.py +++ b/tests/test_mlflow_integration.py @@ -322,17 +322,19 @@ def test_jax_trainer_mlflow_logging( # Set small output folder output_folder = tmp_path / "data" output_folder.mkdir() - generator_config["output_folder"] = str(output_folder) - generator_config["n_parameter_sets"] = 20 # Must be >= n_subruns (default 10) - generator_config["n_subruns"] = 2 # Reduce subruns for faster test + generator_config["output"]["folder"] = str(output_folder) + generator_config["pipeline"]["n_parameter_sets"] = ( + 20 # Must be >= n_subruns (default 10) + ) + generator_config["pipeline"]["n_subruns"] = 2 # Reduce subruns for faster test # Generate one file import ssms - gen = ssms.dataset_generators.lan_mlp.data_generator( - generator_config=generator_config, model_config=model_config + gen = ssms.dataset_generators.lan_mlp.TrainingDataGenerator( + config=generator_config, model_config=model_config ) - gen.generate_data_training_uniform(save=True, verbose=False) + gen.generate_data_training(save=True, verbose=False) # Set up training (dummy_network_train_config_lan is already a dict, not a function) network_config = dummy_network_train_config_lan["network_config"] @@ -417,17 +419,19 @@ def test_pytorch_trainer_mlflow_logging( # Set small output folder output_folder = tmp_path / "data" output_folder.mkdir() - generator_config["output_folder"] = str(output_folder) - generator_config["n_parameter_sets"] = 20 # Must be >= n_subruns (default 10) - generator_config["n_subruns"] = 2 # Reduce subruns for faster test + generator_config["output"]["folder"] = str(output_folder) + generator_config["pipeline"]["n_parameter_sets"] = ( + 20 # Must be >= n_subruns (default 10) + ) + generator_config["pipeline"]["n_subruns"] = 2 # Reduce subruns for faster test # Generate one file import ssms - gen = ssms.dataset_generators.lan_mlp.data_generator( - generator_config=generator_config, model_config=model_config + gen = ssms.dataset_generators.lan_mlp.TrainingDataGenerator( + config=generator_config, model_config=model_config ) - gen.generate_data_training_uniform(save=True, verbose=False) + gen.generate_data_training(save=True, verbose=False) # Set up training (dummy_network_train_config_lan is already a dict, not a function) network_config = dummy_network_train_config_lan["network_config"] @@ -649,16 +653,16 @@ def test_trainer_without_mlflow( output_folder = tmp_path / "data" output_folder.mkdir() - generator_config["output_folder"] = str(output_folder) - generator_config["n_parameter_sets"] = 20 - generator_config["n_subruns"] = 2 + generator_config["output"]["folder"] = str(output_folder) + generator_config["pipeline"]["n_parameter_sets"] = 20 + generator_config["pipeline"]["n_subruns"] = 2 import ssms - gen = ssms.dataset_generators.lan_mlp.data_generator( - generator_config=generator_config, model_config=model_config + gen = ssms.dataset_generators.lan_mlp.TrainingDataGenerator( + config=generator_config, model_config=model_config ) - gen.generate_data_training_uniform(save=True, verbose=False) + gen.generate_data_training(save=True, verbose=False) # Set up training network_config = dummy_network_train_config_lan["network_config"] From ac03ce2021aab84cb7607b56725e2ee5a8033e67 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 18 Mar 2026 22:37:56 -0400 Subject: [PATCH 09/11] improve test coverage --- src/lanfactory/hf/download.py | 2 +- src/lanfactory/hf/upload.py | 2 +- tests/hf/test_model_card.py | 82 ++++++++++++++++++++++ tests/test_torch_mlp.py | 125 ++++++++++++++++++++++++++++++++++ 4 files changed, 209 insertions(+), 2 deletions(-) diff --git a/src/lanfactory/hf/download.py b/src/lanfactory/hf/download.py index 3bf8b5d..dd1e549 100644 --- a/src/lanfactory/hf/download.py +++ b/src/lanfactory/hf/download.py @@ -75,7 +75,7 @@ def download_model( f"Output folder already exists: {output_folder}. Use --force to overwrite." ) - try: + try: # pragma: no cover (requires huggingface_hub optional dependency) from huggingface_hub import hf_hub_download, list_repo_files except ImportError as exc: raise ImportError( diff --git a/src/lanfactory/hf/upload.py b/src/lanfactory/hf/upload.py index a8fd9ee..a5dad3a 100644 --- a/src/lanfactory/hf/upload.py +++ b/src/lanfactory/hf/upload.py @@ -127,7 +127,7 @@ def upload_model( print(f" - {f.name}") return None - try: + try: # pragma: no cover (requires huggingface_hub optional dependency) from huggingface_hub import HfApi, create_repo as hf_create_repo except ImportError as exc: raise ImportError( diff --git a/tests/hf/test_model_card.py b/tests/hf/test_model_card.py index 4680819..37b94bf 100644 --- a/tests/hf/test_model_card.py +++ b/tests/hf/test_model_card.py @@ -194,6 +194,88 @@ def test_custom_usage_example(self): assert "custom_code_here()" in readme + def test_includes_training_loss_field(self): + """Test that training loss field is rendered when provided.""" + config = ModelCardConfig( + training={ + "epochs": 10, + "optimizer": "adam", + "learning_rate": 0.001, + "loss": "huber", + } + ) + + readme = generate_readme(config) + + assert "**Loss:** huber" in readme + + def test_partial_architecture_missing_fields(self): + """Test architecture section with only layer_sizes (no network_type or activations).""" + config = ModelCardConfig( + architecture={ + "layer_sizes": [100, 100, 1], + } + ) + + readme = generate_readme(config) + + assert "## Architecture" in readme + assert "[100, 100, 1]" in readme + assert "Network Type" not in readme + assert "Activations" not in readme + + +class TestLoadModelCardYamlPickleIntegration: + """Tests for _fill_from_pickle_configs integration.""" + + def test_fills_training_from_pickle(self, tmp_path): + """Test that training config is filled from train_config pickle.""" + yaml_content = {"title": "Test Model"} + with open(tmp_path / "model_card.yaml", "w") as f: + yaml.dump(yaml_content, f) + + train_config = { + "n_epochs": 20, + "optimizer": "adam", + "learning_rate": 0.001, + "loss": "huber", + } + with open(tmp_path / "test_train_config.pickle", "wb") as f: + pickle.dump(train_config, f) + + config = load_model_card_yaml(tmp_path) + + assert config.training is not None + assert config.training["epochs"] == 20 + assert config.training["optimizer"] == "adam" + assert config.training["loss"] == "huber" + + def test_handles_corrupt_network_pickle(self, tmp_path): + """Test graceful fallback when network config pickle is corrupt.""" + yaml_content = {"title": "Test Model"} + with open(tmp_path / "model_card.yaml", "w") as f: + yaml.dump(yaml_content, f) + + with open(tmp_path / "bad_network_config.pickle", "wb") as f: + f.write(b"not a valid pickle") + + config = load_model_card_yaml(tmp_path) + + assert config.architecture is None + + def test_handles_corrupt_train_pickle(self, tmp_path): + """Test graceful fallback when train config pickle is corrupt.""" + yaml_content = {"title": "Test Model"} + with open(tmp_path / "model_card.yaml", "w") as f: + yaml.dump(yaml_content, f) + + with open(tmp_path / "bad_train_config.pickle", "wb") as f: + f.write(b"not a valid pickle") + + config = load_model_card_yaml(tmp_path) + + assert config.training is None + class TestWriteReadme: """Tests for write_readme function.""" diff --git a/tests/test_torch_mlp.py b/tests/test_torch_mlp.py index 18d93d9..da38fe9 100644 --- a/tests/test_torch_mlp.py +++ b/tests/test_torch_mlp.py @@ -1020,3 +1020,128 @@ def test_model_trainer_torch_mlp_with_none_train_config(create_mock_data_files): train_dl=train_dl, valid_dl=train_dl, ) + + +# --- Tests for helper / factory functions --- + + +def test_make_dataloader(create_mock_data_files): + """Test make_dataloader creates a working DataLoader.""" + from lanfactory.trainers.torch_mlp import make_dataloader + + file_list = create_mock_data_files(n_files=2) + + dl = make_dataloader( + file_ids=file_list, + batch_size=100, + network_type="lan", + ) + + assert dl is not None + assert dl.dataset.input_dim == 6 + assert dl.dataset.batch_size == 100 + + +def test_make_dataloader_cpn_no_default_lower_bound(create_mock_data_files): + """Test make_dataloader with cpn does not auto-set label_lower_bound.""" + from lanfactory.trainers.torch_mlp import make_dataloader + + file_list = create_mock_data_files(n_files=1) + + data = { + "cpn_data": np.random.randn(1000, 4).astype(np.float32), + "cpn_labels": np.random.randn(1000).astype(np.float32), + } + with open(file_list[0], "wb") as f: + pickle.dump(data, f) + + dl = make_dataloader( + file_ids=file_list, + batch_size=100, + network_type="cpn", + ) + + assert dl.dataset.input_dim == 4 + + +def test_make_train_valid_dataloaders(create_mock_data_files): + """Test make_train_valid_dataloaders splits files correctly.""" + from lanfactory.trainers.torch_mlp import make_train_valid_dataloaders + + file_list = create_mock_data_files(n_files=4) + + train_dl, valid_dl, input_dim = make_train_valid_dataloaders( + file_ids=file_list, + batch_size=100, + network_type="lan", + train_val_split=0.5, + shuffle_files=False, + ) + + assert input_dim == 6 + assert len(train_dl.dataset.file_ids) == 2 + assert len(valid_dl.dataset.file_ids) == 2 + + +def test_make_train_valid_dataloaders_raises_no_train_files(create_mock_data_files): + """Test raises ValueError when split leaves no training files.""" + from lanfactory.trainers.torch_mlp import make_train_valid_dataloaders + + file_list = create_mock_data_files(n_files=1) + + with pytest.raises(ValueError, match="No training files after split"): + make_train_valid_dataloaders( + file_ids=file_list, + batch_size=100, + train_val_split=0.0, + ) + + +def test_make_train_valid_dataloaders_raises_no_valid_files(create_mock_data_files): + """Test raises ValueError when split leaves no validation files.""" + from lanfactory.trainers.torch_mlp import make_train_valid_dataloaders + + file_list = create_mock_data_files(n_files=1) + + with pytest.raises(ValueError, match="No validation files after split"): + make_train_valid_dataloaders( + file_ids=file_list, + batch_size=100, + train_val_split=1.0, + ) + + +def test_torch_mlp_factory_with_dict(): + """Test TorchMLPFactory with dict config.""" + from lanfactory.trainers.torch_mlp import TorchMLPFactory, TorchMLP + + network_config = { + "layer_sizes": [10, 10, 1], + "activations": ["tanh", "tanh", "linear"], + "train_output_type": "logprob", + } + + model = TorchMLPFactory(network_config=network_config, input_dim=6) + + assert isinstance(model, TorchMLP) + assert model.input_shape == 6 + + +def test_torch_mlp_factory_with_pickle_path(tmp_path): + """Test TorchMLPFactory with path to pickled config.""" + from lanfactory.trainers.torch_mlp import TorchMLPFactory, TorchMLP + + network_config = { + "layer_sizes": [10, 10, 1], + "activations": ["tanh", "tanh", "linear"], + "train_output_type": "logprob", + } + + config_path = tmp_path / "network_config.pickle" + with open(config_path, "wb") as f: + pickle.dump(network_config, f) + + model = TorchMLPFactory(network_config=str(config_path), input_dim=6) + + assert isinstance(model, TorchMLP) + assert model.input_shape == 6 From 4a7d97a99024855d93dee9e1f041c759c7bf9012 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 19 Mar 2026 14:57:46 -0400 Subject: [PATCH 10/11] improve tests --- src/lanfactory/hf/download.py | 32 +++++++++++++++------ src/lanfactory/hf/upload.py | 36 +++++++++++++++++------ tests/hf/test_model_card.py | 54 +++++++++++++++++++++++++++++++++++ tests/hf/test_upload.py | 14 +++++++++ 4 files changed, 120 insertions(+), 16 deletions(-) diff --git a/src/lanfactory/hf/download.py b/src/lanfactory/hf/download.py index dd1e549..52bb17c 100644 --- a/src/lanfactory/hf/download.py +++ b/src/lanfactory/hf/download.py @@ -75,7 +75,30 @@ def download_model( f"Output folder already exists: {output_folder}. Use --force to overwrite." ) - try: # pragma: no cover (requires huggingface_hub optional dependency) + return _download_model_hf( # pragma: no cover + network_type=network_type, + model_name=model_name, + output_folder=output_folder, + repo_id=repo_id, + revision=revision, + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + token=token, + ) + + +def _download_model_hf( # pragma: no cover + network_type: str, + model_name: str, + output_folder: Path, + repo_id: str, + revision: str | None, + include_patterns: list[str] | None, + exclude_patterns: list[str] | None, + token: str | None, +) -> Path: + """HF-dependent implementation of download_model.""" + try: from huggingface_hub import hf_hub_download, list_repo_files except ImportError as exc: raise ImportError( @@ -83,13 +106,10 @@ def download_model( "Install it with: pip install lanfactory[hf]" ) from exc - # Create output folder output_folder.mkdir(parents=True, exist_ok=True) - # Build the path prefix for this model path_prefix = f"{network_type}/{model_name}/" - # List files in the repository try: all_files = list_repo_files( repo_id=repo_id, @@ -100,7 +120,6 @@ def download_model( logger.error(f"Failed to list repository files: {e}") raise - # Filter files by path prefix model_files = [f for f in all_files if f.startswith(path_prefix)] if not model_files: @@ -109,7 +128,6 @@ def download_model( f"Available paths: {set(f.split('/')[0] for f in all_files if '/' in f)}" ) - # Apply include/exclude patterns if include_patterns: filtered_files = [] for f in model_files: @@ -140,7 +158,6 @@ def download_model( logger.info(f"Downloading {len(model_files)} files from {repo_id}/{path_prefix}") - # Download each file downloaded_files = [] for file_path in model_files: filename = Path(file_path).name @@ -154,7 +171,6 @@ def download_model( token=token, ) - # Copy to output folder dest_path = output_folder / filename shutil.copy2(local_path, dest_path) downloaded_files.append(dest_path) diff --git a/src/lanfactory/hf/upload.py b/src/lanfactory/hf/upload.py index a5dad3a..dbf5680 100644 --- a/src/lanfactory/hf/upload.py +++ b/src/lanfactory/hf/upload.py @@ -127,7 +127,34 @@ def upload_model( print(f" - {f.name}") return None - try: # pragma: no cover (requires huggingface_hub optional dependency) + return _upload_to_hf( # pragma: no cover + model_folder=model_folder, + model_name=model_name, + files_to_upload=files_to_upload, + path_in_repo=path_in_repo, + repo_id=repo_id, + commit_message=commit_message, + private=private, + create_repo=create_repo, + revision=revision, + token=token, + ) + + +def _upload_to_hf( # pragma: no cover + model_folder: Path, + model_name: str, + files_to_upload: list[Path], + path_in_repo: str, + repo_id: str, + commit_message: str, + private: bool, + create_repo: bool, + revision: str | None, + token: str | None, +) -> str: + """HF-dependent implementation of upload_model.""" + try: from huggingface_hub import HfApi, create_repo as hf_create_repo except ImportError as exc: raise ImportError( @@ -135,10 +162,8 @@ def upload_model( "Install it with: pip install lanfactory[hf]" ) from exc - # Initialize API api = HfApi(token=token) - # Create repo if requested if create_repo: try: hf_create_repo( @@ -153,23 +178,19 @@ def upload_model( logger.error(f"Failed to create repository: {e}") raise - # Generate README.md from model_card.yaml from lanfactory.hf.model_card import load_model_card_yaml, write_readme config = load_model_card_yaml(model_folder) readme_path = write_readme(model_folder, config, model_name) files_to_upload.append(readme_path) - # Create a temporary directory with files organized for upload with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) - # Copy files to temp directory for file_path in files_to_upload: dest = tmp_path / file_path.name shutil.copy2(file_path, dest) - # Upload the folder try: api.upload_folder( folder_path=str(tmp_path), @@ -183,7 +204,6 @@ def upload_model( logger.error(f"Upload failed: {e}") raise - # Construct URL url = f"https://huggingface.co/{repo_id}/tree/main/{path_in_repo}" logger.info(f"Upload successful: {url}") print("\nUpload successful!") diff --git a/tests/hf/test_model_card.py b/tests/hf/test_model_card.py index 37b94bf..82e3525 100644 --- a/tests/hf/test_model_card.py +++ b/tests/hf/test_model_card.py @@ -224,10 +224,64 @@ def test_partial_architecture_missing_fields(self): assert "Network Type" not in readme assert "Activations" not in readme + def test_architecture_all_fields_empty(self): + """Test architecture section where all fields are None/falsy.""" + config = ModelCardConfig( + architecture={ + "network_type": None, + "layer_sizes": None, + "activations": None, + } + ) + + readme = generate_readme(config) + + assert "## Architecture" in readme + assert "Network Type" not in readme + assert "Layer Sizes" not in readme + assert "Activations" not in readme + + def test_training_all_fields_empty(self): + """Test training section where all fields are None/falsy.""" + config = ModelCardConfig( + training={ + "epochs": None, + "optimizer": None, + "learning_rate": None, + "loss": None, + } + ) + + readme = generate_readme(config) + + assert "## Training" in readme + assert "Epochs" not in readme + assert "Optimizer" not in readme + assert "Learning Rate" not in readme + assert "Loss" not in readme + class TestLoadModelCardYamlPickleIntegration: """Tests for _fill_from_pickle_configs integration.""" + def test_skips_pickle_when_yaml_has_data(self, tmp_path): + """Test that pickle loading is skipped when YAML already has architecture and training.""" + yaml_content = { + "title": "Pre-filled Model", + "architecture": {"layer_sizes": [50, 50, 1], "network_type": "lan"}, + "training": {"epochs": 5, "optimizer": "sgd"}, + } + with open(tmp_path / "model_card.yaml", "w") as f: + yaml.dump(yaml_content, f) + + config = load_model_card_yaml(tmp_path) + + assert config.architecture == { + "layer_sizes": [50, 50, 1], + "network_type": "lan", + } + assert config.training == {"epochs": 5, "optimizer": "sgd"} + def test_fills_training_from_pickle(self, tmp_path): """Test that training config is filled from train_config pickle.""" yaml_content = {"title": "Test Model"} diff --git a/tests/hf/test_upload.py b/tests/hf/test_upload.py index e454ebe..6eea1f6 100644 --- a/tests/hf/test_upload.py +++ b/tests/hf/test_upload.py @@ -104,6 +104,20 @@ def test_raises_if_model_card_missing(self, tmp_path): model_name="ddm", ) + def test_raises_if_no_matching_files(self, tmp_path): + """Test raises FileNotFoundError when no files match patterns.""" + yaml_content = {"title": "Test Model"} + with open(tmp_path / "model_card.yaml", "w") as f: + yaml.dump(yaml_content, f) + + with pytest.raises(FileNotFoundError, match="No files matching patterns"): + upload_model( + model_folder=tmp_path, + network_type="lan", + model_name="ddm", + include_patterns=["*.nonexistent"], + ) + def test_dry_run_does_not_upload(self, tmp_path): """Test dry_run shows files but doesn't upload.""" # Create model_card.yaml From 62952062192c2edf765628698bf593411e80da9a Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 19 Mar 2026 15:50:53 -0400 Subject: [PATCH 11/11] address copilot --- src/lanfactory/hf/upload.py | 2 +- src/lanfactory/trainers/jax_mlp.py | 13 ++++++++++--- src/lanfactory/trainers/torch_mlp.py | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/lanfactory/hf/upload.py b/src/lanfactory/hf/upload.py index dbf5680..f85f891 100644 --- a/src/lanfactory/hf/upload.py +++ b/src/lanfactory/hf/upload.py @@ -204,7 +204,7 @@ def _upload_to_hf( # pragma: no cover logger.error(f"Upload failed: {e}") raise - url = f"https://huggingface.co/{repo_id}/tree/main/{path_in_repo}" + url = f"https://huggingface.co/{repo_id}/tree/{revision or 'main'}/{path_in_repo}" logger.info(f"Upload successful: {url}") print("\nUpload successful!") print(f"View your model at: {url}") diff --git a/src/lanfactory/trainers/jax_mlp.py b/src/lanfactory/trainers/jax_mlp.py index d2df40a..3c3968c 100755 --- a/src/lanfactory/trainers/jax_mlp.py +++ b/src/lanfactory/trainers/jax_mlp.py @@ -24,11 +24,13 @@ print("mlflow not available") -def JaxMLPFactory(network_config: dict | str = {}, train: bool = True) -> "JaxMLP": +def JaxMLPFactory( + network_config: dict | str | None = None, train: bool = True +) -> "JaxMLP": """Factory function to create a JaxMLP object. Arguments --------- - network_config (dict | str): + network_config (dict | str | None): Dictionary containing the network configuration or path to pickled config. train (bool): Whether the model should be trained or not. @@ -36,9 +38,14 @@ def JaxMLPFactory(network_config: dict | str = {}, train: bool = True) -> "JaxML ------- JaxMLP class initialized with the correct network configuration. """ + if network_config is None: + raise ValueError( + "network_config must be provided (dict or path to pickle file)" + ) if isinstance(network_config, str): - network_config_internal = pickle.load(open(network_config, "rb")) + with open(network_config, "rb") as f: + network_config_internal = pickle.load(f) elif isinstance(network_config, dict): network_config_internal = network_config else: diff --git a/src/lanfactory/trainers/torch_mlp.py b/src/lanfactory/trainers/torch_mlp.py index a128709..35bdc30 100755 --- a/src/lanfactory/trainers/torch_mlp.py +++ b/src/lanfactory/trainers/torch_mlp.py @@ -294,7 +294,7 @@ def make_train_valid_dataloaders( batch_size=batch_size, network_type=network_type, label_lower_bound=label_lower_bound, - shuffle=True, + shuffle=False, num_workers=num_workers, pin_memory=pin_memory, )