44 commits
9e5e8be  Add deprecation warnings for model_config attributes in HSSMBase (cpaniaguam, Mar 11, 2026)
74acc67  Refactor HSSMBase to support BaseModelConfig and improve model_config… (cpaniaguam, Mar 11, 2026)
2acc19d  Add model configuration building methods to BaseModelConfig and Confi… (cpaniaguam, Mar 11, 2026)
97bf90f  Refactor model configuration handling in HSSMBase and HSSM classes to… (cpaniaguam, Mar 11, 2026)
f75f5e4  Add properties to BaseModelConfig for parameter and extra field counts (cpaniaguam, Mar 11, 2026)
b9b4839  Refactor RLSSM attributes to use public naming convention for configu… (cpaniaguam, Mar 11, 2026)
821392f  Refactor test_rlssm_panel_attrs to use public attributes for particip… (cpaniaguam, Mar 11, 2026)
07169c9  Refactor HSSMBase to streamline model configuration handling and upda… (cpaniaguam, Mar 11, 2026)
0395ec2  Refactor BaseModelConfig and RLSSMConfig by removing unused abstract … (cpaniaguam, Mar 11, 2026)
626301d  Refactor HSSM class to remove Config inheritance and add initializati… (cpaniaguam, Mar 11, 2026)
6c13443  Refactor RLSSM class to remove RLSSMConfig inheritance and streamline… (cpaniaguam, Mar 11, 2026)
66296f0  Refactor Config and RLSSMConfig classes to use concrete types in meth… (cpaniaguam, Mar 11, 2026)
801f235  Update Config class parameter types for choices to improve type safety (cpaniaguam, Mar 11, 2026)
607874c  Update choices method to accept a tuple for model_config.choices (cpaniaguam, Mar 11, 2026)
9e25a32  Add tests for model configuration handling and choices logic in Config (cpaniaguam, Mar 11, 2026)
2b2d66b  Enhance HSSMBase initialization with safe default for constructor arg… (cpaniaguam, Mar 11, 2026)
d5b9d80  Update model_config validation to check for non-null choices (cpaniaguam, Mar 11, 2026)
9bf18ea  Refactor HSSM distribution method to use typed model_config attribute… (cpaniaguam, Mar 11, 2026)
9af3e95  Update test cases to use tuples for choices in model configuration (cpaniaguam, Mar 12, 2026)
c2e09d9  Refactor RLSSM to utilize model_config for list_params and loglik, en… (cpaniaguam, Mar 12, 2026)
1aa19f2  Fix typo in comment regarding model_config choices validation (cpaniaguam, Mar 12, 2026)
0ea0998  Refactor RLSSM tests to access model configuration attributes directl… (cpaniaguam, Mar 12, 2026)
8f526f4  Update attribute comparison in compare_hssm_class_attributes to use m… (cpaniaguam, Mar 12, 2026)
5dd68a5  Update test assertions to access model configuration attributes directly (cpaniaguam, Mar 12, 2026)
7054ccd  Refactor model configuration normalization to streamline choices hand… (cpaniaguam, Mar 12, 2026)
5e816bc  Refactor choices handling in Config class to improve clarity and logging (cpaniaguam, Mar 12, 2026)
9f6a7ef  Refactor _normalize_model_config_with_choices to improve input handli… (cpaniaguam, Mar 12, 2026)
49415ab  Refactor likelihood callable construction to simplify logic and enhan… (cpaniaguam, Mar 12, 2026)
4452f36  Refactor _make_model_distribution to utilize model_config for loglik … (cpaniaguam, Mar 12, 2026)
c34e562  Fix formatting in HSSM class for consistency in likelihood callable p… (cpaniaguam, Mar 12, 2026)
3e86974  Fix formatting in HSSM class for consistency in likelihood callable p… (cpaniaguam, Mar 12, 2026)
4a5aefc  Refactor HSSM class to use typed model_config attributes directly and… (cpaniaguam, Mar 12, 2026)
cdc7763  Restore make_model_dist in HSSM (cpaniaguam, Mar 13, 2026)
e3cbcb7  Remove deprecated properties and methods from HSSMBase class (cpaniaguam, Mar 13, 2026)
7e481e0  Enhance HSSMBase class to prevent overwriting _init_args if already s… (cpaniaguam, Mar 13, 2026)
3432bca  Clarify model_config parameter documentation in HSSMBase class to spe… (cpaniaguam, Mar 13, 2026)
31bd6f1  Enhance HSSMBase class documentation to clarify filtering of internal… (cpaniaguam, Mar 13, 2026)
296810b  Update model_config parameter documentation in HSSM class to support … (cpaniaguam, Mar 13, 2026)
95779bc  Add test to validate external model config fallback in _build_model_c… (cpaniaguam, Mar 13, 2026)
37ea9be  Update sampling parameters in test_rlssm_sample_smoke for speed (cpaniaguam, Mar 17, 2026)
9a37dd8  Add RLSSM quickstart notebook for model instantiation and sampling de… (cpaniaguam, Mar 17, 2026)
43ec652  Add RLSSM Quickstart tutorial to navigation and plugins (cpaniaguam, Mar 17, 2026)
7f1e6ff  Remove redundant next steps and streamline summary in RLSSM quickstar… (cpaniaguam, Mar 17, 2026)
e604406  Refactor RLSSM class to use model_config instead of rlssm_config for … (cpaniaguam, Mar 18, 2026)
351 changes: 351 additions & 0 deletions docs/tutorials/rlssm_quickstart.ipynb
@@ -0,0 +1,351 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1b9b429d",
"metadata": {},
"source": [
"# RLSSM Quickstart: Instantiation, Model Building, and Sampling\n",
"\n",
"This notebook provides a minimal end-to-end demonstration of the `RLSSM` class:\n",
"\n",
"1. **Load** a balanced-panel two-armed bandit dataset\n",
"2. **Define** an annotated learning function and the angle SSM log-likelihood\n",
"3. **Configure** and **instantiate** an `RLSSM` model\n",
"4. **Inspect** the built Bambi / PyMC model\n",
"5. **Run** a minimal 2-draw sampling smoke test\n",
"\n",
"For a full treatment — simulating data, hierarchical formulas, meaningful sampling, and posterior visualization — see:\n",
"- [rlssm_tutorial.ipynb](rlssm_tutorial.ipynb)\n",
"- [add_custom_rlssm_model.ipynb](add_custom_rlssm_model.ipynb)"
]
},
{
"cell_type": "markdown",
"id": "bf38d7f7",
"metadata": {},
"source": [
"## 1. Imports and Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d764731",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import hssm\n",
"from hssm import RLSSM, RLSSMConfig\n",
"from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx\n",
"from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise\n",
"from hssm.utils import annotate_function\n",
"\n",
"# RLSSM requires float32 throughout (JAX default).\n",
"hssm.set_floatX(\"float32\", update_jax=True)"
]
},
{
"cell_type": "markdown",
"id": "df12303f",
"metadata": {},
"source": [
"## 2. Load the Dataset\n",
"\n",
"We use a small synthetic two-armed bandit dataset from the HSSM test fixtures. \n",
"It is a **balanced panel**: every participant has the same number of trials. \n",
"Columns: `participant_id`, `trial_id`, `rt`, `response`, `feedback`.\n",
"\n",
"> **Note:** You can also generate data with\n",
"> [`ssm-simulators`](https://github.com/AlexanderFengler/ssm-simulators).\n",
"> See `rlssm_tutorial.ipynb` for an example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2ef5f6e",
"metadata": {},
"outputs": [],
"source": [
"# Path relative to docs/tutorials/ when running inside the HSSM repo.\n",
"_fixture_path = Path(\"../../tests/fixtures/rldm_data.npy\")\n",
"raw = np.load(_fixture_path, allow_pickle=True).item()\n",
"data = pd.DataFrame(raw[\"data\"])\n",
"\n",
"n_participants = data[\"participant_id\"].nunique()\n",
"n_trials = len(data) // n_participants\n",
"\n",
"print(data.head())\n",
"print(f\"\\nParticipants: {n_participants} | Trials per participant: {n_trials}\")"
]
},
{
"cell_type": "markdown",
"id": "8c310290",
"metadata": {},
"source": [
"## 3. Define the Learning Process\n",
"\n",
"The RL learning process is a JAX function that, given a subject's trial sequence, computes\n",
"the trial-wise drift rate `v` via a Q-learning update rule. \n",
"\n",
"`annotate_function` attaches `.inputs`, `.outputs`, and (optionally) `.computed` metadata\n",
"that the RLSSM likelihood builder uses to automatically construct the input matrix for the\n",
"decision process.\n",
"\n",
"- **inputs** — columns that the function reads (free parameters + data columns)\n",
"- **outputs** — what the function produces (here: `v`, the drift rate)\n",
"\n",
"Here we annotate the built-in `compute_v_subject_wise` function, which implements a simple\n",
"Rescorla-Wagner Q-learning update for a two-armed bandit task."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbcea122",
"metadata": {},
"outputs": [],
"source": [
"compute_v_annotated = annotate_function(\n",
" inputs=[\"rl_alpha\", \"scaler\", \"response\", \"feedback\"],\n",
" outputs=[\"v\"],\n",
")(compute_v_subject_wise)\n",
"\n",
"print(\"Learning function inputs :\", compute_v_annotated.inputs)\n",
"print(\"Learning function outputs:\", compute_v_annotated.outputs)"
]
},
{
"cell_type": "markdown",
"id": "7a03305a",
"metadata": {},
"source": [
"## 4. Define the Decision (SSM) Log-Likelihood\n",
"\n",
"The decision process uses the **angle model** likelihood, loaded from an ONNX file.\n",
"`make_jax_matrix_logp_funcs_from_onnx` returns a JAX callable that accepts a\n",
"2-D matrix whose columns are `[v, a, z, t, theta, rt, response]` and returns\n",
"per-trial log-probabilities.\n",
"\n",
"We then annotate that callable so the builder knows:\n",
"- which columns the matrix contains (`inputs`)\n",
"- that `v` itself is *computed* by the learning function (not a free parameter)\n",
"\n",
"The ONNX file is loaded from the local test fixture when running inside the HSSM\n",
"repository; otherwise it is downloaded from the HuggingFace Hub (`franklab/HSSM`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60bbc036",
"metadata": {},
"outputs": [],
"source": [
"# Use the local fixture when available; fall back to HuggingFace download.\n",
"_local_onnx = Path(\"../../tests/fixtures/angle.onnx\").resolve()\n",
"_onnx_model = str(_local_onnx) if _local_onnx.exists() else \"angle.onnx\"\n",
"\n",
"_angle_logp_jax = make_jax_matrix_logp_funcs_from_onnx(model=_onnx_model)\n",
"\n",
"angle_logp_func = annotate_function(\n",
" inputs=[\"v\", \"a\", \"z\", \"t\", \"theta\", \"rt\", \"response\"],\n",
" outputs=[\"logp\"],\n",
" computed={\"v\": compute_v_annotated},\n",
")(_angle_logp_jax)\n",
"\n",
"print(\"SSM logp inputs :\", angle_logp_func.inputs)\n",
"print(\"SSM logp outputs:\", angle_logp_func.outputs)\n",
"print(\"Computed deps :\", list(angle_logp_func.computed.keys()))"
]
},
{
"cell_type": "markdown",
"id": "cf8f5b63",
"metadata": {},
"source": [
"## 5. Configure the Model with `RLSSMConfig`\n",
"\n",
"`RLSSMConfig` collects all the information the RLSSM class needs:\n",
"\n",
"| Field | Purpose |\n",
"|-------|---------|\n",
"| `model_name` | Identifier string for the configuration |\n",
"| `decision_process` | Name of the SSM (e.g. `\"angle\"`) |\n",
"| `list_params` | Ordered list of *free* parameters to sample |\n",
"| `params_default` | Starting / default values for each parameter |\n",
"| `bounds` | Prior bounds for each parameter |\n",
"| `learning_process` | Dict mapping computed param name → annotated learning function |\n",
"| `extra_fields` | Extra data columns required by the learning function |\n",
"| `ssm_logp_func` | Annotated JAX callable for the decision-process likelihood |"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4beba1bc",
"metadata": {},
"outputs": [],
"source": [
"rlssm_config = RLSSMConfig(\n",
" model_name=\"rlssm_angle_quickstart\",\n",
" loglik_kind=\"approx_differentiable\",\n",
" decision_process=\"angle\",\n",
" decision_process_loglik_kind=\"approx_differentiable\",\n",
" learning_process_loglik_kind=\"blackbox\",\n",
" list_params=[\"rl_alpha\", \"scaler\", \"a\", \"theta\", \"t\", \"z\"],\n",
" params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5],\n",
" bounds={\n",
" \"rl_alpha\": (0.0, 1.0),\n",
" \"scaler\": (0.0, 10.0),\n",
" \"a\": (0.1, 3.0),\n",
" \"theta\": (-0.1, 0.1),\n",
" \"t\": (0.001, 1.0),\n",
" \"z\": (0.1, 0.9),\n",
" },\n",
" learning_process={\"v\": compute_v_annotated},\n",
" response=[\"rt\", \"response\"],\n",
" choices=[0, 1],\n",
" extra_fields=[\"feedback\"],\n",
" ssm_logp_func=angle_logp_func,\n",
")\n",
"\n",
"print(\"Model name :\", rlssm_config.model_name)\n",
"print(\"Free params :\", rlssm_config.list_params)"
]
},
{
"cell_type": "markdown",
"id": "924ee4c7",
"metadata": {},
"source": [
"## 6. Instantiate the `RLSSM` Model\n",
"\n",
"Passing `data` and `rlssm_config` to `RLSSM`:\n",
"\n",
"- validates the balanced-panel requirement\n",
"- builds a differentiable PyTensor Op that chains the RL learning step and the\n",
" angle log-likelihood\n",
"- constructs the Bambi / PyMC model internally\n",
"\n",
"Note that `v` (the drift rate) is *not* a free parameter — it is computed inside\n",
"the Op by the Q-learning update and therefore does not appear in `model.params`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f8da79a",
"metadata": {},
"outputs": [],
"source": [
"model = RLSSM(data=data, rlssm_config=rlssm_config)\n",
"\n",
"assert isinstance(model, RLSSM)\n",
"print(\"Model type :\", type(model).__name__)\n",
"print(\"Participants :\", model.n_participants)\n",
"print(\"Trials/subj :\", model.n_trials)\n",
"print(\"Free parameters :\", list(model.params.keys()))\n",
"assert \"rl_alpha\" in model.params, \"rl_alpha must be a free parameter\"\n",
"assert \"v\" not in model.params, \"v is computed, not a free parameter\"\n",
"model"
]
},
{
"cell_type": "markdown",
"id": "f7f39940",
"metadata": {},
"source": [
"## 7. Inspect the Built Model\n",
"\n",
"After construction, `model.model` exposes the underlying **Bambi model** and\n",
"`model.pymc_model` exposes the **PyMC model** context — useful for debugging\n",
"or customizing priors."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0558ad4",
"metadata": {},
"outputs": [],
"source": [
"print(\"=== Bambi model ===\")\n",
"print(model.model)\n",
"\n",
"print(\"\\n=== PyMC model ===\")\n",
"print(model.pymc_model)"
]
},
{
"cell_type": "markdown",
"id": "f4e50110",
"metadata": {},
"source": [
"## 8. Sampling\n",
"\n",
"A minimal sampling run — 2 draws, 2 tuning steps, 1 chain — confirms that the full\n",
"computational graph (Q-learning scan → angle logp → NUTS gradient) is wired correctly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96ce3238",
"metadata": {},
"outputs": [],
"source": [
"trace = model.sample(draws=2, tune=2, chains=1, cores=1, sampler=\"numpyro\", target_accept=0.9)\n",
"\n",
"assert trace is not None\n",
"print(trace)"
]
},
{
"cell_type": "markdown",
"id": "a784a468",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"This notebook showed how to:\n",
"\n",
"1. Load a balanced-panel dataset (`rldm_data.npy`)\n",
"2. Annotate a Q-learning function with `annotate_function`\n",
"3. Load the angle ONNX likelihood and annotate it so the builder can assemble the input matrix\n",
"4. Define an `RLSSMConfig` and pass it to `RLSSM`\n",
"5. Confirm model structure (free params, Bambi / PyMC objects)\n",
"6. Run a 2-draw sampling smoke test that returns an `arviz.InferenceData` object"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "hssm",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
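Section 3 of the notebook describes `compute_v_subject_wise` as a Rescorla-Wagner Q-learning update that turns a subject's choices and feedback into trial-wise drift rates. As an illustrative sketch only (plain NumPy, not HSSM's actual JAX implementation, and the Q-value initialization and drift definition here are assumptions), the delta-rule idea looks like this:

```python
import numpy as np

def q_learning_drift_sketch(responses, feedback, rl_alpha, scaler, q_init=0.5):
    """Illustrative Rescorla-Wagner update for a two-armed bandit.

    A hypothetical stand-in for the notebook's `compute_v_subject_wise`:
    on each trial the chosen option's Q-value moves toward the received
    feedback, and the drift rate v is the scaled Q-value difference.
    """
    q = np.full(2, q_init, dtype=float)   # Q-values for the two options
    v = np.empty(len(responses), dtype=float)
    for t, (choice, reward) in enumerate(zip(responses, feedback)):
        v[t] = scaler * (q[1] - q[0])                  # drift before this trial's update
        q[choice] += rl_alpha * (reward - q[choice])   # RW delta rule
    return v

# With rl_alpha = 1 the chosen option's Q jumps straight to the feedback value.
v = q_learning_drift_sketch([1, 1, 0], [1.0, 1.0, 0.0], rl_alpha=1.0, scaler=2.0)
```

In the real model this computation runs inside a differentiable PyTensor Op (a JAX scan), which is why `v` never appears in `model.params`.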
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -43,6 +43,7 @@ nav:
- Hierarchical Variational Inference: tutorials/variational_inference_hierarchical.ipynb
- Using HSSM low-level API directly with PyMC: tutorials/pymc.ipynb
- Reinforcement Learning - Sequential Sampling Models (RLSSM): tutorials/rlssm_tutorial.ipynb
- RLSSM Quickstart: tutorials/rlssm_quickstart.ipynb
- Add custom RLSSM models: tutorials/add_custom_rlssm_model.ipynb
- Custom models: tutorials/jax_callable_contribution_onnx_example.ipynb
- Custom models from onnx files: tutorials/blackbox_contribution_onnx_example.ipynb
@@ -91,6 +92,7 @@ plugins:
- tutorials/hssm_tutorial_workshop_2.ipynb
- tutorials/add_custom_rlssm_model.ipynb
- tutorials/rlssm_tutorial.ipynb
- tutorials/rlssm_quickstart.ipynb
- tutorials/lapse_prob_and_dist.ipynb
- tutorials/plotting.ipynb
- tutorials/scientific_workflow_hssm.ipynb
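The quickstart notebook computes `n_trials = len(data) // n_participants`, which is only meaningful because `RLSSM` validates the balanced-panel requirement. A standalone pre-flight check along these lines (a hypothetical helper, not part of the HSSM API) can surface unbalanced data before model construction:

```python
import pandas as pd

def check_balanced_panel(data: pd.DataFrame, id_col: str = "participant_id") -> int:
    """Return the common trials-per-participant count.

    Raises ValueError when participants have unequal trial counts, in which
    case `len(data) // n_participants` would silently be misleading.
    """
    counts = data.groupby(id_col).size()
    if counts.nunique() != 1:
        raise ValueError(f"Unbalanced panel: trial counts vary ({counts.to_dict()})")
    return int(counts.iloc[0])

# Two participants with three trials each: a balanced panel.
toy = pd.DataFrame({
    "participant_id": [0, 0, 0, 1, 1, 1],
    "rt": [0.4, 0.5, 0.6, 0.3, 0.7, 0.5],
    "response": [1, 0, 1, 1, 1, 0],
    "feedback": [1.0, 0.0, 1.0, 1.0, 0.0, 0.0],
})
n_trials = check_balanced_panel(toy)
```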