11 changes: 11 additions & 0 deletions .claude/rules/uv.md
@@ -0,0 +1,11 @@
---
description: Enforce uv as the package manager for all Python operations
globs:
- "**/*.py"
- "**/pyproject.toml"
---

- Always use `uv run` to execute commands — never bare `python`, `pytest`, `ruff`, or other tools.
- Never use `pip install` — use `uv sync` (with `--extra` or `--all-groups` flags) to manage dependencies.
Copilot AI Mar 28, 2026

“Never use `pip install`” is too broad, given that the repo’s README still documents `pip install lanfactory` for end users. Consider scoping this rule explicitly to repository development and CI (e.g., “When working in this repo, don’t use pip; use uv…”) so it doesn’t conflict with the published installation guidance.

Suggested change
- Never use `pip install` use `uv sync` (with `--extra` or `--all-groups` flags) to manage dependencies.
- When working in this repo (local development, CI, or scripts), do not use `pip install`; instead use `uv sync` (with `--extra` or `--all-groups` flags) to manage dependencies. End‑user installation instructions in README or external docs may still use `pip install` where appropriate.

Copilot uses AI. Check for mistakes.
- The `uv.lock` file is the source of truth for resolved dependency versions.
- When adding dependencies, add them to `pyproject.toml` and run `uv sync`.
Comment on lines +10 to +11
Copilot AI Mar 28, 2026

This rule says `uv.lock` is the source of truth, but the repository currently lists `uv.lock` in `.gitignore`, so the file is never committed or reviewed. To make the rule enforceable, either stop ignoring `uv.lock` and commit it, or adjust the rule to reflect the current workflow (no checked-in lockfile).

Suggested change
- The `uv.lock` file is the source of truth for resolved dependency versions.
- When adding dependencies, add them to `pyproject.toml` and run `uv sync`.
- `pyproject.toml` is the source of truth for dependencies; `uv.lock` is a local resolution artifact and is not committed.
- When adding dependencies, add them to `pyproject.toml` and run `uv sync` to update your local `uv.lock`.

138 changes: 138 additions & 0 deletions CLAUDE.md
@@ -0,0 +1,138 @@
# LANfactory — Project Context for Claude

## What is LANfactory?

Lightweight Python package for training Likelihood Approximation Networks (LANs), Choice Probability Networks (CPNs), and Option Probability Networks (OPNs) using PyTorch or JAX/Flax backends. Trained networks are exported to ONNX format and uploaded to HuggingFace for consumption by HSSM. This package sits in the middle of the HSSM ecosystem: it depends on ssm-simulators for training data and produces the neural network artifacts that HSSM uses at inference time. For ecosystem-wide context, see the HSSMSpine repo.

## Project Structure

```
src/lanfactory/ # Main package
cli/ # Typer CLIs: jaxtrain, torchtrain, transform-onnx, upload-hf, download-hf
config/ # Default network and training configs (LAN, CPN, OPN)
trainers/ # Training implementations (torch_mlp.py, jax_mlp.py)
onnx/ # PyTorch → ONNX export
hf/ # HuggingFace Hub integration (upload, download, model cards)
utils/ # Config save/load, MLflow utilities
tests/ # pytest suite (trainers, CLI, ONNX, HuggingFace, E2E)
docs/ # MkDocs documentation + tutorial notebooks
notebooks/ # Test notebooks
```

## Build & Tooling

- **Build system:** setuptools (pure Python, no compiled extensions)
- **Package manager:** uv (with `uv.lock`)
Copilot AI Mar 28, 2026

`uv.lock` is currently ignored by git in this repo (`.gitignore` contains `uv.lock`), so stating “`uv.lock` is the source of truth” (and implying it’s present) is misleading. Either commit `uv.lock` and remove it from `.gitignore`, or update this doc to reflect that the repo doesn’t track a lockfile.

Suggested change
- **Package manager:** uv (with `uv.lock`)
- **Package manager:** uv (no `uv.lock` committed; dependencies resolved from `pyproject.toml`)

- **Python:** >3.10, <3.14 (classifiers target 3.11, 3.12, 3.13)
- **Linting:** ruff (line length 120, via pre-commit)
- **Type checking:** mypy
- **No system dependencies** — unlike ssm-simulators, this is pure Python + PyTorch/Flax

## Common Commands

```bash
# Install all dependencies (dev + optional)
Copilot AI Mar 28, 2026

`uv sync --all-groups` installs dependency groups (e.g. `dev`) but does not install optional extras, so the comment “dev + optional” is inaccurate. Consider either changing the text to “dev” only, or updating the command to include extras (e.g., add `--all-extras` or an explicit `--extra ...`).

Suggested change
# Install all dependencies (dev + optional)
# Install all dependency groups (e.g. dev)

uv sync --all-groups

# Run tests
uv run pytest tests/

# Lint & format
uv run ruff check src/lanfactory && uv run ruff format --check .

# Build docs
uv run mkdocs build
uv run mkdocs serve

# Train a network (PyTorch)
uv run torchtrain --config-path <yaml> --training-data-folder <dir> --networks-path-base <dir>

# Train a network (JAX)
uv run jaxtrain --config-path <yaml> --training-data-folder <dir> --networks-path-base <dir>

# Export PyTorch model to ONNX
uv run transform-onnx --network-config-file config.pickle --state-dict-file model.pt \
--input-shape 6 --output-onnx-file model.onnx

# Upload trained models to HuggingFace
uv run upload-hf --model-folder <dir> --network-type lan --model-name ddm

# Download models from HuggingFace
uv run download-hf --network-type lan --model-name ddm --output-folder <dir>
```
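
When the training command is launched from a script rather than typed by hand, it can help to build the argument list in one place. The sketch below assumes only the three flags shown in the commands above; `torchtrain_command` is a hypothetical helper and the paths are placeholders.

```python
import shlex


def torchtrain_command(config_path: str, data_dir: str, networks_dir: str) -> list[str]:
    """Build the argv for the PyTorch training command listed above."""
    return [
        "uv", "run", "torchtrain",
        "--config-path", config_path,
        "--training-data-folder", data_dir,
        "--networks-path-base", networks_dir,
    ]


# Placeholder paths, for illustration only.
argv = torchtrain_command("config.yaml", "data/", "networks/")
print(shlex.join(argv))
# To actually launch training: subprocess.run(argv, check=True)
```

Passing a list to `subprocess.run` avoids shell quoting problems when a path contains spaces.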

## Key Architecture Patterns

### Network Types

| Type | Full Name | Output | Loss | Use Case |
|------|-----------|--------|------|----------|
| LAN | Likelihood Approximation Network | logprob | Huber | Log-likelihood approximation |
| CPN | Choice Probability Network | logits | BCE with logits | Choice probability estimation |
| OPN | Option Probability Network | logits | BCE with logits | Option probability estimation |

All three use the same MLP architecture (`[100, 100, 1]` default, tanh activations)
but differ in output type and loss function.
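
To make the difference between the two losses in the table concrete, here is a plain-Python sketch of each, evaluated on a single prediction. It is an illustration of the standard formulas only: the threshold `delta = 1.0` is an assumption, and the package's own trainers may compute these losses through their backend libraries instead.

```python
import math


def huber(pred: float, target: float, delta: float = 1.0) -> float:
    """Huber loss: quadratic for small residuals, linear for large ones."""
    r = abs(pred - target)
    if r <= delta:
        return 0.5 * r * r
    return delta * (r - 0.5 * delta)


def bce_with_logits(logit: float, label: float) -> float:
    """Binary cross-entropy on a raw logit, in the numerically stable form."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))
```

The Huber loss suits the LAN case because the target is a real-valued log-probability, while the CPN and OPN targets are probabilities, which is what binary cross-entropy on logits is designed for.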

### Training Backends

- **PyTorch** (`torchtrain` CLI, `trainers/torch_mlp.py`) — primary backend.
Supports CUDA, ONNX export, full training loop with validation.
- **JAX/Flax** (`jaxtrain` CLI, `trainers/jax_mlp.py`) — alternative backend.
Uses optax optimizers. No native ONNX export (train in JAX, convert via PyTorch if needed).

### ONNX Export Pipeline

PyTorch model → `torch.onnx.export()` → `.onnx` file. This is the format
HSSM consumes at runtime. Only PyTorch models can be directly exported to ONNX.

### HuggingFace Integration

- **Upload:** `lanfactory.hf.upload_model()` — uploads `.onnx`, `.pt`, config pickles,
and auto-generated README to `franklab/HSSM` on HuggingFace.
Requires `model_card.yaml` in the model folder.
- **Download:** `lanfactory.hf.download_model()` — downloads by network type + model name.
- **Default repo:** `franklab/HSSM`
- **Optional dependency:** `huggingface-hub>=0.20.0` (install via `uv sync --extra hf`)

### Config System

Training configs are YAML files parsed by the CLI. Key fields:
- `NETWORK_TYPE`: `lan`, `cpn`, or `opn`
- `layer_sizes`, `activations`: network architecture
- `n_epochs`, `learning_rate`, `loss`, `optimizer`: training hyperparams
- `cpu_batch_size`, `gpu_batch_size`: device-specific batch sizes

Default configs available in `lanfactory.config.network_configs`.
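
The key fields listed above can be pictured as a flat mapping. The sketch below is an assumption-laden illustration: the field names come from the list, but every value is a made-up placeholder (including the `"linear"` output activation and the `"huber"`/`"adam"` strings), the real YAML files may nest or name things differently, and `config_problems` is a hypothetical helper rather than part of the package.

```python
# Field names from the list above; all values are illustrative placeholders.
example_config = {
    "NETWORK_TYPE": "lan",
    "layer_sizes": [100, 100, 1],
    "activations": ["tanh", "tanh", "linear"],
    "n_epochs": 20,
    "learning_rate": 1e-3,
    "loss": "huber",
    "optimizer": "adam",
    "cpu_batch_size": 1024,
    "gpu_batch_size": 4096,
}

# Treating every listed field as required is an assumption of this sketch.
REQUIRED_FIELDS = set(example_config)
NETWORK_TYPES = {"lan", "cpn", "opn"}


def config_problems(config: dict) -> list[str]:
    """List missing fields and an unrecognised NETWORK_TYPE, if any."""
    problems = [f"missing field: {name}" for name in sorted(REQUIRED_FIELDS - set(config))]
    network_type = config.get("NETWORK_TYPE")
    if network_type is not None and network_type not in NETWORK_TYPES:
        problems.append(f"unknown NETWORK_TYPE: {network_type}")
    return problems
```

Running a check like this before training starts surfaces a typo in a field name immediately instead of partway through a run.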

### MLflow Integration

Optional experiment tracking via MLflow. CLI flags: `--run-name`, `--experiment-name`,
`--tracking-uri`, `--artifact-location`. Supports resuming runs via `--run-id`.
Comment on lines +110 to +111
Copilot AI Mar 28, 2026

The MLflow CLI flag names in this section don’t match the actual CLI options in `lanfactory.cli.{torch_train,jax_train}`, which use `--mlflow-run-name`, `--mlflow-experiment-name`, `--mlflow-tracking-uri`, `--mlflow-artifact-location`, and `--mlflow-run-id`. Updating the flags here will prevent users from copying commands that fail.

Suggested change
Optional experiment tracking via MLflow. CLI flags: `--run-name`, `--experiment-name`,
`--tracking-uri`, `--artifact-location`. Supports resuming runs via `--run-id`.
Optional experiment tracking via MLflow. CLI flags: `--mlflow-run-name`, `--mlflow-experiment-name`,
`--mlflow-tracking-uri`, `--mlflow-artifact-location`. Supports resuming runs via `--mlflow-run-id`.


## CLI Entry Points

| Command | Module | Purpose |
|---------|--------|---------|
| `torchtrain` | `lanfactory.cli.torch_train` | Train PyTorch networks from YAML config |
| `jaxtrain` | `lanfactory.cli.jax_train` | Train JAX networks from YAML config |
| `transform-onnx` | `lanfactory.onnx.transform_onnx` | Convert PyTorch model → ONNX |
| `upload-hf` | `lanfactory.cli.upload_hf` | Upload trained models to HuggingFace |
| `download-hf` | `lanfactory.cli.download_hf` | Download models from HuggingFace |

## CI Workflows

| Workflow | Purpose |
|----------|---------|
| `run_tests.yml` | Tests on Python 3.11/3.12/3.13 + ruff lint/format + codecov |
| `build_wheels.yml` | Build sdist, upload to TestPyPI → PyPI on release publish |

## Known Issues

- `__init__.py` version (`0.5.3`) is out of sync with `pyproject.toml` (`0.6.1`)

## Compaction

When compacting, preserve: file list of modified files, the three network types
(LAN/CPN/OPN) and their differences, CLI entry points, ONNX export flow,
HuggingFace upload/download interface, and all test commands.