Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/hf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:::lanfactory.hf
75 changes: 31 additions & 44 deletions docs/basic_tutorial/basic_tutorial_cpn_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -166,60 +166,40 @@
"source": [
"#### Prepare for Training\n",
"\n",
"Next we set up dataloaders for training with pytorch. The `LANfactory` uses custom dataloaders, taking into account particularities of the expected training data.\n",
"Specifically, we expect to receive a bunch of training data files (the present example generates only one), where each file hosts a large number of training examples. \n",
"So we want to define a dataloader which spits out batches from data with a specific training data file, and keeps checking when to load in a new file. \n",
"This is implemented via the `DatasetTorch` class in `lanfactory.trainers`, which inherits from `torch.utils.data.Dataset` and prespecifies a `batch_size`. Finally, this is supplied to a [`DataLoader`](https://pytorch.org/docs/stable/data.html), for which we set the `batch_size` argument to `None` (this disables the `DataLoader`'s automatic batching, since `DatasetTorch` already handles batching internally).\n",
"Next we set up dataloaders for training with pytorch. The `LANfactory` provides convenient helper functions for this.\n",
"\n",
"The `DatasetTorch` class is then called as an iterator via the DataLoader and takes care of batching as well as file loading internally. \n",
"The `make_train_valid_dataloaders` function handles:\n",
"- Splitting your data files into training and validation sets\n",
"- Creating the appropriate `DatasetTorch` objects\n",
"- Wrapping them in PyTorch `DataLoader` objects with sensible defaults\n",
"\n",
"You may choose your own way of defining the `DataLoader` classes; downstream, you are simply expected to supply one."
"Under the hood, this uses the `DatasetTorch` class which handles batching and file loading internally."
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# MAKE DATALOADERS\n",
"TRAINING_TYPE = \"cpn\"\n",
"\n",
"# List of datafiles (here only one)\n",
"# List of datafiles\n",
"folder_ = Path(\"torch_nb_data/\") / \"cpn_mlp\" / \"training_data\"\n",
"file_list_ = list(folder_.glob(\"*\"))\n",
"\n",
"# Training dataset\n",
"torch_training_dataset = lanfactory.trainers.DatasetTorch(\n",
" file_ids=file_list_,\n",
" batch_size=BATCH_SIZE,\n",
" features_key=f\"{TRAINING_TYPE}_data\",\n",
" label_key=f\"{TRAINING_TYPE}_labels\",\n",
")\n",
"\n",
"torch_training_dataloader = torch.utils.data.DataLoader(\n",
" torch_training_dataset,\n",
" shuffle=True,\n",
" batch_size=None,\n",
" num_workers=1,\n",
" pin_memory=True,\n",
")\n",
"file_list_ = list(folder_.glob(\"*.pickle\"))\n",
"\n",
"# Validation dataset\n",
"torch_validation_dataset = lanfactory.trainers.DatasetTorch(\n",
"# Create train and validation dataloaders with a single function call\n",
"torch_training_dataloader, torch_validation_dataloader, input_dim = lanfactory.trainers.make_train_valid_dataloaders(\n",
" file_ids=file_list_,\n",
" batch_size=BATCH_SIZE,\n",
" features_key=f\"{TRAINING_TYPE}_data\",\n",
" label_key=f\"{TRAINING_TYPE}_labels\",\n",
" network_type=\"cpn\",\n",
" train_val_split=0.9,\n",
")\n",
"\n",
"torch_validation_dataloader = torch.utils.data.DataLoader(\n",
" torch_validation_dataset,\n",
" shuffle=True,\n",
" batch_size=None,\n",
" num_workers=1,\n",
" pin_memory=True,\n",
")"
"print(f\"Training batches: {len(torch_training_dataloader)}\")\n",
"print(f\"Validation batches: {len(torch_validation_dataloader)}\")\n",
"print(f\"Input dimension: {input_dim}\")"
]
},
{
Expand Down Expand Up @@ -275,7 +255,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -302,20 +282,26 @@
],
"source": [
"# LOAD NETWORK\n",
"net = lanfactory.trainers.TorchMLP(\n",
"# Option 1: Using the factory function (recommended)\n",
"net = lanfactory.trainers.TorchMLPFactory(\n",
" network_config=deepcopy(network_config),\n",
" input_shape=torch_training_dataset.input_dim,\n",
" save_folder=Path(\"torch_nb_data/torch_models\") / TRAINING_TYPE / MODEL,\n",
" generative_model_id=MODEL,\n",
" input_dim=input_dim,\n",
" network_type=\"cpn\",\n",
")\n",
"\n",
"# Option 2: Direct instantiation (also works)\n",
"# net = lanfactory.trainers.TorchMLP(\n",
"# network_config=deepcopy(network_config),\n",
"# input_shape=input_dim,\n",
"# network_type=\"cpn\",\n",
"# )\n",
"\n",
"# SAVE CONFIGS\n",
"lanfactory.utils.save_configs(\n",
" model_id=MODEL + \"_torch_\",\n",
" save_folder=Path(\"torch_nb_data/torch_models\") / TRAINING_TYPE / MODEL,\n",
" network_config=network_config,\n",
" train_config=train_config,\n",
" allow_abs_path_folder_generation=True,\n",
")"
]
},
Expand Down Expand Up @@ -465,7 +451,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -482,10 +468,11 @@
"model_path = Path(\"torch_nb_data/torch_models\") / TRAINING_TYPE / MODEL\n",
"network_file_path = next(model_path.glob(\"*state_dict*\"))\n",
"\n",
"# LoadTorchMLPInfer automatically calls eval() for inference mode\n",
"network = lanfactory.trainers.LoadTorchMLPInfer(\n",
" model_file_path=network_file_path,\n",
" network_config=network_config,\n",
" input_dim=torch_training_dataset.input_dim,\n",
" input_dim=input_dim,\n",
" network_type=TRAINING_TYPE,\n",
")"
]
Expand Down
68 changes: 20 additions & 48 deletions docs/basic_tutorial/basic_tutorial_lan_jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -192,66 +192,38 @@
"source": [
"#### Prepare for Training\n",
"\n",
"Next we set up dataloaders for training with pytorch. The `LANfactory` uses custom dataloaders, taking into account particularities of the expected training data.\n",
"Specifically, we expect to receive a bunch of training data files (the present example generates only one), where each file hosts a large number of training examples. \n",
"So we want to define a dataloader which spits out batches from data with a specific training data file, and keeps checking when to load in a new file. \n",
"This is implemented via the `DatasetTorch` class in `lanfactory.trainers`, which inherits from `torch.utils.data.Dataset` and prespecifies a `batch_size`. Finally, this is supplied to a [`DataLoader`](https://pytorch.org/docs/stable/data.html), for which we set the `batch_size` argument to `None` (this disables the `DataLoader`'s automatic batching, since `DatasetTorch` already handles batching internally).\n",
"Next we set up dataloaders for training. The `LANfactory` provides convenient helper functions for this.\n",
"\n",
"The `DatasetTorch` class is then called as an iterator via the DataLoader and takes care of batching as well as file loading internally. \n",
"The `make_train_valid_dataloaders` function handles:\n",
"- Splitting your data files into training and validation sets\n",
"- Creating the appropriate `DatasetTorch` objects\n",
"- Wrapping them in PyTorch `DataLoader` objects with sensible defaults\n",
"\n",
"You may choose your own way of defining the `DataLoader` classes; downstream, you are simply expected to supply one."
"The data is returned as numpy arrays, which JAX handles seamlessly."
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# folder_ = OUT_FOLDER\n",
"file_list_ = [os.path.join(OUT_FOLDER, file_) for file_ in os.listdir(OUT_FOLDER)]\n",
"# MAKE DATALOADERS\n",
"\n",
"# List of datafiles\n",
"file_list_ = list(OUT_FOLDER.glob(\"*.pickle\"))\n",
"\n",
"# INDEPENDENT TESTS OF DATALOADERS\n",
"# Training dataset\n",
"jax_training_dataset = lanfactory.trainers.DatasetTorch(\n",
"# Create train and validation dataloaders with a single function call\n",
"jax_training_dataloader, jax_validation_dataloader, input_dim = lanfactory.trainers.make_train_valid_dataloaders(\n",
" file_ids=file_list_,\n",
" batch_size=(\n",
" train_config[DEVICE + \"_batch_size\"]\n",
" if torch.cuda.is_available()\n",
" else train_config[DEVICE + \"_batch_size\"]\n",
" ),\n",
" label_lower_bound=np.log(1e-10),\n",
" features_key=\"lan_data\",\n",
" label_key=\"lan_labels\",\n",
" out_framework=\"jax\",\n",
" batch_size=train_config[\"cpu_batch_size\"],\n",
" network_type=\"lan\",\n",
" train_val_split=0.9,\n",
")\n",
"\n",
"jax_training_dataloader = torch.utils.data.DataLoader(\n",
" jax_training_dataset, shuffle=True, batch_size=None, num_workers=1, pin_memory=True\n",
")\n",
"\n",
"# Validation dataset\n",
"jax_validation_dataset = lanfactory.trainers.DatasetTorch(\n",
" file_ids=file_list_,\n",
" batch_size=(\n",
" train_config[DEVICE + \"_batch_size\"]\n",
" if torch.cuda.is_available()\n",
" else train_config[DEVICE + \"_batch_size\"]\n",
" ),\n",
" label_lower_bound=np.log(1e-10),\n",
" features_key=\"lan_data\",\n",
" label_key=\"lan_labels\",\n",
" out_framework=\"jax\",\n",
")\n",
"\n",
"jax_validation_dataloader = torch.utils.data.DataLoader(\n",
" jax_validation_dataset,\n",
" shuffle=True,\n",
" batch_size=None,\n",
" num_workers=1,\n",
" pin_memory=True,\n",
")"
"print(f\"Training batches: {len(jax_training_dataloader)}\")\n",
"print(f\"Validation batches: {len(jax_validation_dataloader)}\")\n",
"print(f\"Input dimension: {input_dim}\")"
]
},
{
Expand All @@ -269,7 +241,7 @@
"source": [
"# LOAD NETWORK\n",
"# Test properties of network\n",
"jax_net = lanfactory.trainers.MLPJaxFactory(network_config=network_config, train=True)\n",
"jax_net = lanfactory.trainers.JaxMLPFactory(network_config=network_config, train=True)\n",
"\n",
"# Save model config\n",
"# model_folder = os.path.join(\"data\", \"jax_models\", MODEL)\n",
Expand Down Expand Up @@ -386,7 +358,7 @@
"# Loaded Net\n",
"# Test passing network config as path and as object\n",
"\n",
"jax_infer = lanfactory.trainers.MLPJaxFactory(\n",
"jax_infer = lanfactory.trainers.JaxMLPFactory(\n",
"\t network_config=network_config,\n",
" train=False,\n",
" )"
Expand Down
81 changes: 31 additions & 50 deletions docs/basic_tutorial/basic_tutorial_lan_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -211,64 +211,39 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we set up dataloaders for training with pytorch. The `LANfactory` uses custom dataloaders, taking into account particularities of the expected training data.\n",
"Specifically, we expect to receive a bunch of training data files (the present example generates only one), where each file hosts a large number of training examples. \n",
"So we want to define a dataloader which spits out batches from data with a specific training data file, and keeps checking when to load in a new file. \n",
"This is implemented via the `DatasetTorch` class in `lanfactory.trainers`, which inherits from `torch.utils.data.Dataset` and prespecifies a `batch_size`. Finally, this is supplied to a [`DataLoader`](https://pytorch.org/docs/stable/data.html), for which we set the `batch_size` argument to `None` (this disables the `DataLoader`'s automatic batching, since `DatasetTorch` already handles batching internally).\n",
"Next we set up dataloaders for training with pytorch. The `LANfactory` provides convenient helper functions for this.\n",
"\n",
"The `DatasetTorch` class is then called as an iterator via the DataLoader and takes care of batching as well as file loading internally. \n",
"The `make_train_valid_dataloaders` function handles:\n",
"- Splitting your data files into training and validation sets\n",
"- Creating the appropriate `DatasetTorch` objects\n",
"- Wrapping them in PyTorch `DataLoader` objects with sensible defaults\n",
"\n",
"You may choose your own way of defining the `DataLoader` classes; downstream, you are simply expected to supply one."
"Under the hood, this uses the `DatasetTorch` class which handles batching and file loading internally."
]
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \n",
"TRAINING_TYPE = 'lan' # 'lan', 'cpn', 'opn'\n",
"\n",
"# MAKE DATALOADERS\n",
"\n",
"# List of datafiles (here only one)\n",
"folder_ = Path(\"torch_nb_data/lan_mlp\") / MODEL # + \"/training_data_0_nbins_0_n_1000/\"\n",
"file_list_ = [str(p) for p in folder_.iterdir()]\n",
"# List of datafiles\n",
"folder_ = Path(\"torch_nb_data/lan_mlp\") / MODEL\n",
"file_list_ = list(folder_.glob(\"*.pickle\"))\n",
"\n",
"# Training dataset\n",
"torch_training_dataset = lanfactory.trainers.DatasetTorch(\n",
"# Create train and validation dataloaders with a single function call\n",
"torch_training_dataloader, torch_validation_dataloader, input_dim = lanfactory.trainers.make_train_valid_dataloaders(\n",
" file_ids=file_list_,\n",
" batch_size=train_config[\"cpu_batch_size\"],\n",
" features_key=f\"{TRAINING_TYPE}_data\",\n",
" label_key=f\"{TRAINING_TYPE}_labels\",\n",
" label_lower_bound=np.log(1e-10)\n",
")\n",
"\n",
"torch_training_dataloader = torch.utils.data.DataLoader(\n",
" torch_training_dataset,\n",
" shuffle=True,\n",
" batch_size=None,\n",
" num_workers=1,\n",
" pin_memory=True,\n",
")\n",
"\n",
"# Validation dataset\n",
"torch_validation_dataset = lanfactory.trainers.DatasetTorch(\n",
" file_ids=file_list_,\n",
" batch_size=train_config[\"cpu_batch_size\"],\n",
" features_key=f\"{TRAINING_TYPE}_data\",\n",
" label_key=f\"{TRAINING_TYPE}_labels\",\n",
" label_lower_bound=np.log(1e-10)\n",
" network_type=\"lan\",\n",
" train_val_split=0.9,\n",
")\n",
"\n",
"torch_validation_dataloader = torch.utils.data.DataLoader(\n",
" torch_validation_dataset,\n",
" shuffle=True,\n",
" batch_size=None,\n",
" num_workers=1,\n",
" pin_memory=True,\n",
")"
"print(f\"Training batches: {len(torch_training_dataloader)}\")\n",
"print(f\"Validation batches: {len(torch_validation_dataloader)}\")\n",
"print(f\"Input dimension: {input_dim}\")"
]
},
{
Expand All @@ -280,7 +255,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -304,21 +279,26 @@
],
"source": [
"# LOAD NETWORK\n",
"net = lanfactory.trainers.TorchMLP(\n",
"# Option 1: Using the factory function (recommended)\n",
"net = lanfactory.trainers.TorchMLPFactory(\n",
" network_config=deepcopy(network_config),\n",
" input_dim=input_dim,\n",
" network_type=\"lan\",\n",
" input_shape=torch_training_dataset.input_dim,\n",
" save_folder=str(Path(\"/torch_nb_data/torch_models/\")),\n",
" generative_model_id=MODEL,\n",
")\n",
"\n",
"# Option 2: Direct instantiation (also works)\n",
"# net = lanfactory.trainers.TorchMLP(\n",
"# network_config=deepcopy(network_config),\n",
"# input_shape=input_dim,\n",
"# network_type=\"lan\",\n",
"# )\n",
"\n",
"# SAVE CONFIGS\n",
"lanfactory.utils.save_configs(\n",
" model_id=MODEL + \"_torch_\",\n",
" save_folder=str(Path(\"torch_nb_data/torch_models/lan/\" + MODEL + \"/\")),\n",
" network_config=network_config,\n",
" train_config=train_config,\n",
" allow_abs_path_folder_generation=True,\n",
")"
]
},
Expand Down Expand Up @@ -514,7 +494,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -534,10 +514,11 @@
"model_path = Path(\"torch_nb_data/torch_models\") / (MODEL + \"_lan\")\n",
"network_file_path = next(model_path.glob(\"*state_dict*\"))\n",
"\n",
"# LoadTorchMLPInfer automatically calls eval() for inference mode\n",
"network = lanfactory.trainers.LoadTorchMLPInfer(\n",
" model_file_path=network_file_path,\n",
" network_config=network_config,\n",
" input_dim=torch_training_dataset.input_dim,\n",
" input_dim=input_dim,\n",
")"
]
},
Expand Down
Loading
Loading