diff --git a/docs/tutorials/rlssm_quickstart.ipynb b/docs/tutorials/rlssm_quickstart.ipynb new file mode 100644 index 000000000..e4ddc7636 --- /dev/null +++ b/docs/tutorials/rlssm_quickstart.ipynb @@ -0,0 +1,351 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1b9b429d", + "metadata": {}, + "source": [ + "# RLSSM Quickstart: Instantiation, Model Building, and Sampling\n", + "\n", + "This notebook provides a minimal end-to-end demonstration of the `RLSSM` class:\n", + "\n", + "1. **Load** a balanced-panel two-armed bandit dataset\n", + "2. **Define** an annotated learning function and the angle SSM log-likelihood\n", + "3. **Configure** and **instantiate** an `RLSSM` model\n", + "4. **Inspect** the built Bambi / PyMC model\n", + "5. **Run** a minimal 2-draw sampling smoke test\n", + "\n", + "For a full treatment — simulating data, hierarchical formulas, meaningful sampling, and posterior visualization — see:\n", + "- [rlssm_tutorial.ipynb](rlssm_tutorial.ipynb)\n", + "- [add_custom_rlssm_model.ipynb](add_custom_rlssm_model.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "bf38d7f7", + "metadata": {}, + "source": [ + "## 1. Imports and Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d764731", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import hssm\n", + "from hssm import RLSSM, RLSSMConfig\n", + "from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx\n", + "from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise\n", + "from hssm.utils import annotate_function\n", + "\n", + "# RLSSM requires float32 throughout (JAX default).\n", + "hssm.set_floatX(\"float32\", update_jax=True)" + ] + }, + { + "cell_type": "markdown", + "id": "df12303f", + "metadata": {}, + "source": [ + "## 2. 
Load the Dataset\n", + "\n", + "We use a small synthetic two-armed bandit dataset from the HSSM test fixtures. \n", + "It is a **balanced panel**: every participant has the same number of trials. \n", + "Columns: `participant_id`, `trial_id`, `rt`, `response`, `feedback`.\n", + "\n", + "> **Note:** You can also generate data with\n", + "> [`ssm-simulators`](https://github.com/AlexanderFengler/ssm-simulators).\n", + "> See `rlssm_tutorial.ipynb` for an example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2ef5f6e", + "metadata": {}, + "outputs": [], + "source": [ + "# Path relative to docs/tutorials/ when running inside the HSSM repo.\n", + "_fixture_path = Path(\"../../tests/fixtures/rldm_data.npy\")\n", + "raw = np.load(_fixture_path, allow_pickle=True).item()\n", + "data = pd.DataFrame(raw[\"data\"])\n", + "\n", + "n_participants = data[\"participant_id\"].nunique()\n", + "n_trials = len(data) // n_participants\n", + "\n", + "print(data.head())\n", + "print(f\"\\nParticipants: {n_participants} | Trials per participant: {n_trials}\")" + ] + }, + { + "cell_type": "markdown", + "id": "8c310290", + "metadata": {}, + "source": [ + "## 3. Define the Learning Process\n", + "\n", + "The RL learning process is a JAX function that, given a subject's trial sequence, computes\n", + "the trial-wise drift rate `v` via a Q-learning update rule. \n", + "\n", + "`annotate_function` attaches `.inputs`, `.outputs`, and (optionally) `.computed` metadata\n", + "that the RLSSM likelihood builder uses to automatically construct the input matrix for the\n", + "decision process.\n", + "\n", + "- **inputs** — columns that the function reads (free parameters + data columns)\n", + "- **outputs** — what the function produces (here: `v`, the drift rate)\n", + "\n", + "Here we annotate the built-in `compute_v_subject_wise` function, which implements a simple\n", + "Rescorla-Wagner Q-learning update for a two-armed bandit task." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbcea122", + "metadata": {}, + "outputs": [], + "source": [ + "compute_v_annotated = annotate_function(\n", + " inputs=[\"rl_alpha\", \"scaler\", \"response\", \"feedback\"],\n", + " outputs=[\"v\"],\n", + ")(compute_v_subject_wise)\n", + "\n", + "print(\"Learning function inputs :\", compute_v_annotated.inputs)\n", + "print(\"Learning function outputs:\", compute_v_annotated.outputs)" + ] + }, + { + "cell_type": "markdown", + "id": "7a03305a", + "metadata": {}, + "source": [ + "## 4. Define the Decision (SSM) Log-Likelihood\n", + "\n", + "The decision process uses the **angle model** likelihood, loaded from an ONNX file.\n", + "`make_jax_matrix_logp_funcs_from_onnx` returns a JAX callable that accepts a\n", + "2-D matrix whose columns are `[v, a, z, t, theta, rt, response]` and returns\n", + "per-trial log-probabilities.\n", + "\n", + "We then annotate that callable so the builder knows:\n", + "- which columns the matrix contains (`inputs`)\n", + "- that `v` itself is *computed* by the learning function (not a free parameter)\n", + "\n", + "The ONNX file is loaded from the local test fixture when running inside the HSSM\n", + "repository; otherwise it is downloaded from the HuggingFace Hub (`franklab/HSSM`)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60bbc036", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the local fixture when available; fall back to HuggingFace download.\n", + "_local_onnx = Path(\"../../tests/fixtures/angle.onnx\").resolve()\n", + "_onnx_model = str(_local_onnx) if _local_onnx.exists() else \"angle.onnx\"\n", + "\n", + "_angle_logp_jax = make_jax_matrix_logp_funcs_from_onnx(model=_onnx_model)\n", + "\n", + "angle_logp_func = annotate_function(\n", + " inputs=[\"v\", \"a\", \"z\", \"t\", \"theta\", \"rt\", \"response\"],\n", + " outputs=[\"logp\"],\n", + " computed={\"v\": compute_v_annotated},\n", + ")(_angle_logp_jax)\n", + "\n", + "print(\"SSM logp inputs :\", angle_logp_func.inputs)\n", + "print(\"SSM logp outputs:\", angle_logp_func.outputs)\n", + "print(\"Computed deps :\", list(angle_logp_func.computed.keys()))" + ] + }, + { + "cell_type": "markdown", + "id": "cf8f5b63", + "metadata": {}, + "source": [ + "## 5. Configure the Model with `RLSSMConfig`\n", + "\n", + "`RLSSMConfig` collects all the information the RLSSM class needs:\n", + "\n", + "| Field | Purpose |\n", + "|-------|---------|\n", + "| `model_name` | Identifier string for the configuration |\n", + "| `decision_process` | Name of the SSM (e.g. 
`\"angle\"`) |\n", + "| `list_params` | Ordered list of *free* parameters to sample |\n", + "| `params_default` | Starting / default values for each parameter |\n", + "| `bounds` | Prior bounds for each parameter |\n", + "| `learning_process` | Dict mapping computed param name → annotated learning function |\n", + "| `extra_fields` | Extra data columns required by the learning function |\n", + "| `ssm_logp_func` | Annotated JAX callable for the decision-process likelihood |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4beba1bc", + "metadata": {}, + "outputs": [], + "source": [ + "rlssm_config = RLSSMConfig(\n", + " model_name=\"rlssm_angle_quickstart\",\n", + " loglik_kind=\"approx_differentiable\",\n", + " decision_process=\"angle\",\n", + " decision_process_loglik_kind=\"approx_differentiable\",\n", + " learning_process_loglik_kind=\"blackbox\",\n", + " list_params=[\"rl_alpha\", \"scaler\", \"a\", \"theta\", \"t\", \"z\"],\n", + " params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5],\n", + " bounds={\n", + " \"rl_alpha\": (0.0, 1.0),\n", + " \"scaler\": (0.0, 10.0),\n", + " \"a\": (0.1, 3.0),\n", + " \"theta\": (-0.1, 0.1),\n", + " \"t\": (0.001, 1.0),\n", + " \"z\": (0.1, 0.9),\n", + " },\n", + " learning_process={\"v\": compute_v_annotated},\n", + " response=[\"rt\", \"response\"],\n", + " choices=[0, 1],\n", + " extra_fields=[\"feedback\"],\n", + " ssm_logp_func=angle_logp_func,\n", + ")\n", + "\n", + "print(\"Model name :\", rlssm_config.model_name)\n", + "print(\"Free params :\", rlssm_config.list_params)" + ] + }, + { + "cell_type": "markdown", + "id": "924ee4c7", + "metadata": {}, + "source": [ + "## 6. 
Instantiate the `RLSSM` Model\n", + "\n", + "Passing `data` and `rlssm_config` to `RLSSM`:\n", + "\n", + "- validates the balanced-panel requirement\n", + "- builds a differentiable PyTensor Op that chains the RL learning step and the\n", + " angle log-likelihood\n", + "- constructs the Bambi / PyMC model internally\n", + "\n", + "Note that `v` (the drift rate) is *not* a free parameter — it is computed inside\n", + "the Op by the Q-learning update and therefore does not appear in `model.params`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f8da79a", + "metadata": {}, + "outputs": [], + "source": [ + "model = RLSSM(data=data, rlssm_config=rlssm_config)\n", + "\n", + "assert isinstance(model, RLSSM)\n", + "print(\"Model type :\", type(model).__name__)\n", + "print(\"Participants :\", model.n_participants)\n", + "print(\"Trials/subj :\", model.n_trials)\n", + "print(\"Free parameters :\", list(model.params.keys()))\n", + "assert \"rl_alpha\" in model.params, \"rl_alpha must be a free parameter\"\n", + "assert \"v\" not in model.params, \"v is computed, not a free parameter\"\n", + "model" + ] + }, + { + "cell_type": "markdown", + "id": "f7f39940", + "metadata": {}, + "source": [ + "## 7. Inspect the Built Model\n", + "\n", + "After construction, `model.model` exposes the underlying **Bambi model** and\n", + "`model.pymc_model` exposes the **PyMC model** context — useful for debugging\n", + "or customizing priors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0558ad4", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=== Bambi model ===\")\n", + "print(model.model)\n", + "\n", + "print(\"\\n=== PyMC model ===\")\n", + "print(model.pymc_model)" + ] + }, + { + "cell_type": "markdown", + "id": "f4e50110", + "metadata": {}, + "source": [ + "## 8. 
Sampling\n", + "\n", + "A minimal sampling run — 2 draws, 2 tuning steps, 1 chain — confirms that the full\n", + "computational graph (Q-learning scan → angle logp → NUTS gradient) is wired correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96ce3238", + "metadata": {}, + "outputs": [], + "source": [ + "trace = model.sample(draws=2, tune=2, chains=1, cores=1, sampler=\"numpyro\", target_accept=0.9)\n", + "\n", + "assert trace is not None\n", + "print(trace)" + ] + }, + { + "cell_type": "markdown", + "id": "a784a468", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook showed how to:\n", + "\n", + "1. Load a balanced-panel dataset (`rldm_data.npy`)\n", + "2. Annotate a Q-learning function with `annotate_function`\n", + "3. Load the angle ONNX likelihood and annotate it so the builder can assemble the input matrix\n", + "4. Define an `RLSSMConfig` and pass it to `RLSSM`\n", + "5. Confirm model structure (free params, Bambi / PyMC objects)\n", + "6. 
Run a 2-draw sampling smoke test that returns an `arviz.InferenceData` object" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hssm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mkdocs.yml b/mkdocs.yml index 66ab3e103..286ef08c8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -43,6 +43,7 @@ nav: - Hierarchical Variational Inference: tutorials/variational_inference_hierarchical.ipynb - Using HSSM low-level API directly with PyMC: tutorials/pymc.ipynb - Reinforcement Learning - Sequential Sampling Models (RLSSM): tutorials/rlssm_tutorial.ipynb + - RLSSM Quickstart: tutorials/rlssm_quickstart.ipynb - Add custom RLSSM models: tutorials/add_custom_rlssm_model.ipynb - Custom models: tutorials/jax_callable_contribution_onnx_example.ipynb - Custom models from onnx files: tutorials/blackbox_contribution_onnx_example.ipynb @@ -91,6 +92,7 @@ plugins: - tutorials/hssm_tutorial_workshop_2.ipynb - tutorials/add_custom_rlssm_model.ipynb - tutorials/rlssm_tutorial.ipynb + - tutorials/rlssm_quickstart.ipynb - tutorials/lapse_prob_and_dist.ipynb - tutorials/plotting.ipynb - tutorials/scientific_workflow_hssm.ipynb diff --git a/src/hssm/base.py b/src/hssm/base.py index 1b4e1db59..95d5c97ca 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -10,7 +10,6 @@ import logging from abc import ABC, abstractmethod from copy import deepcopy -from inspect import signature from os import PathLike from pathlib import Path from typing import Any, Callable, Literal, Optional, Union, cast, get_args @@ -29,9 +28,8 @@ from bambi.model_components import DistributionalComponent from bambi.transformations import transformations_namespace from 
pymc.model.transform.conditioning import do -from ssms.config import model_config as ssms_model_config -from hssm._types import LoglikKind, SupportedModels +from hssm._types import SupportedModels from hssm.data_validator import DataValidatorMixin from hssm.defaults import ( INITVAL_JITTER_SETTINGS, @@ -49,7 +47,7 @@ ) from . import plotting -from .config import Config, ModelConfig +from .config import BaseModelConfig from .param import Params from .param import UserParam as Param @@ -116,65 +114,12 @@ class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin): A list of dictionaries specifying parameter specifications to include in the model. If left unspecified, defaults will be used for all parameter specifications. Defaults to None. - model_config : optional - A dictionary containing the model configuration information. If None is - provided, defaults will be used if there are any. Defaults to None. - Fields for this `dict` are usually: - - - `"list_params"`: a list of parameters indicating the parameters of the model. - The order in which the parameters are specified in this list is important. - Values for each parameter will be passed to the likelihood function in this - order. - - `"backend"`: Only used when `loglik_kind` is `approx_differentiable` and - an onnx file is supplied for the likelihood approximation network (LAN). - Valid values are `"jax"` or `"pytensor"`. It determines whether the LAN in - ONNX should be converted to `"jax"` or `"pytensor"`. If not provided, - `jax` will be used for maximum performance. - - `"default_priors"`: A `dict` indicating the default priors for each parameter. - - `"bounds"`: A `dict` indicating the boundaries for each parameter. In the case - of LAN, these bounds are training boundaries. - - `"rv"`: Optional. Can be a `RandomVariable` class containing the user's own - `rng_fn` function for sampling from the distribution that the user is - supplying. 
If not supplied, HSSM will automatically generate a - `RandomVariable` using the simulator identified by `model` from the - `ssm_simulators` package. If `model` is not supported in `ssm_simulators`, - a warning will be raised letting the user know that sampling from the - `RandomVariable` will result in errors. - - `"extra_fields"`: Optional. A list of strings indicating the additional - columns in `data` that will be passed to the likelihood function for - calculation. This is helpful if the likelihood function depends on data - other than the observed data and the parameter values. - loglik : optional - A likelihood function. Defaults to None. Requirements are: - - 1. if `loglik_kind` is `"analytical"` or `"blackbox"`, a pm.Distribution, a - pytensor Op, or a Python callable can be used. Signatures are: - - `pm.Distribution`: needs to have parameters specified exactly as listed in - `list_params` - - `pytensor.graph.Op` and `Callable`: needs to accept the parameters - specified exactly as listed in `list_params` - 2. If `loglik_kind` is `"approx_differentiable"`, then in addition to the - specifications above, a `str` or `Pathlike` can also be used to specify a - path to an `onnx` file. If a `str` is provided, HSSM will first look locally - for an `onnx` file. If that is not successful, HSSM will try to download - that `onnx` file from Hugging Face hub. - 3. It can also be `None`, in which case a default likelihood function will be - used - loglik_kind : optional - A string that specifies the kind of log-likelihood function specified with - `loglik`. Defaults to `None`. Can be one of the following: - - - `"analytical"`: an analytical (approximation) likelihood function. It is - differentiable and can be used with samplers that requires differentiation. - - `"approx_differentiable"`: a likelihood approximation network (LAN) likelihood - function. It is differentiable and can be used with samplers that requires - differentiation. 
- - `"blackbox"`: a black box likelihood function. It is typically NOT - differentiable. - - `None`, in which a default will be used. For `ddm` type of models, the default - will be `analytical`. For other models supported, it will be - `approx_differentiable`. If the model is a custom one, a ValueError - will be raised. + model_config + A fully initialised :class:`~hssm.config.BaseModelConfig` instance + (typically :class:`~hssm.config.Config`) produced by the subclass + before calling ``super().__init__``. All likelihood, parameter, and + data information used by :class:`HSSMBase` is drawn from this object, + and it must provide populated ``loglik`` and ``list_params`` fields. p_outlier : optional The fixed lapse probability or the prior distribution of the lapse probability. Defaults to a fixed value of 0.05. When `None`, the lapse probability will not @@ -267,14 +212,8 @@ class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin): def __init__( self, data: pd.DataFrame, - model: SupportedModels | str = "ddm", - choices: list[int] | None = None, + model_config: BaseModelConfig, include: list[dict[str, Any] | Param] | None = None, - model_config: ModelConfig | dict | None = None, - loglik: ( - str | PathLike | Callable | pytensor.graph.Op | type[pm.Distribution] | None - ) = None, - loglik_kind: LoglikKind | None = None, p_outlier: float | dict | bmb.Prior | None = 0.05, lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), global_formula: str | None = None, @@ -290,14 +229,6 @@ def __init__( initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], **kwargs, ): - # ===== init args for save/load models ===== - self._init_args = { - k: v for k, v in locals().items() if k not in ["self", "kwargs"] - } - if kwargs: - self._init_args.update(kwargs) - # endregion - # ===== Input Data & Configuration ===== self.data = data.copy() self.global_formula = global_formula @@ -305,6 +236,16 @@ def __init__( self.prior_settings = prior_settings 
self.missing_data_value = -999.0 + # Store a safe default for the constructor-arguments snapshot so that + # pickling / save-load cannot raise AttributeError if a subclass forgets + # to call `_store_init_args(locals(), kwargs)` early. Subclasses are + # still expected to overwrite this with the real snapshot. However, + # do not overwrite if a subclass already set `_init_args` prior to + # calling `super().__init__()` (the subclass may capture its + # constructor args before delegating to the base class). + if not hasattr(self, "_init_args"): + self._init_args: dict[str, Any] = {} + # Set up additional namespace for formula evaluation additional_namespace = transformations_namespace.copy() if extra_namespace is not None: @@ -322,12 +263,8 @@ def __init__( self._initvals: dict[str, Any] = {} self.initval_jitter = initval_jitter - # region ===== Construct a model_config from defaults and user inputs ===== - self.model_config: Config = self._build_model_config( - model, loglik_kind, model_config, choices - ) - self.model_config.update_loglik(loglik) - self.model_config.validate() + # region ===== Store the pre-built config ===== + self.model_config: BaseModelConfig = model_config # endregion # region ===== Set up shortcuts so old code will work ====== @@ -500,120 +437,54 @@ def _validate_fixed_vectors(self) -> None: f"{len(param.prior)}, but data has {len(self.data)} rows." ) - @classmethod - def _build_model_config( - cls, - model: SupportedModels | str, - loglik_kind: LoglikKind | None, - model_config: ModelConfig | dict | None, - choices: list[int] | None, - ) -> Config: - """Build a ModelConfig object from defaults and user inputs. - - Parameters - ---------- - model : SupportedModels | str - The model name. - loglik_kind : LoglikKind | None - The kind of likelihood function. - model_config : ModelConfig | dict | None - User-provided model configuration. - choices : list[int] | None - User-provided choices list. 
+ @classproperty + def supported_models(cls) -> tuple[SupportedModels, ...]: + """Get a tuple of all supported models. Returns ------- - Config - A complete Config object with choices and other settings applied. + tuple[SupportedModels, ...] + A tuple containing all supported model names. """ - # Start with defaults - # get_config_class is provided by Config/RLSSMConfig mixin through MRO - config = cls.get_config_class().from_defaults(model, loglik_kind) # type: ignore[attr-defined] - - # Handle user-provided model_config - if model_config is not None: - # Check if choices already exists in the provided config - has_choices = ( - isinstance(model_config, dict) - and "choices" in model_config - or isinstance(model_config, ModelConfig) - and model_config.choices is not None - ) + return get_args(SupportedModels) - # Handle choices conflict or missing choices - if choices is not None: - if has_choices: - _logger.info( - "choices list provided in both model_config and " - "as an argument directly." - " Using the one provided in model_config. \n" - "We recommend providing choices in model_config." - ) - else: - # Add choices to a copy of the config to avoid mutating input - if isinstance(model_config, dict): - model_config = {**model_config, "choices": choices} - else: # ModelConfig instance - # Create a dict from the ModelConfig and add choices - model_config_dict = { - k: getattr(model_config, k) - for k in model_config.__dataclass_fields__ - if getattr(model_config, k) is not None - } - model_config_dict["choices"] = choices - model_config = model_config_dict - - # Convert dict to ModelConfig if needed and update - final_config = ( - model_config - if isinstance(model_config, ModelConfig) - else ModelConfig(**model_config) - ) - config.update_config(final_config) + @staticmethod + def _store_init_args( + local_vars: dict[str, Any], extra_kwargs: dict[str, Any] + ) -> dict[str, Any]: + """Capture subclass ``__init__`` arguments for save/load serialisation. 
- # Handle default config (no model_config provided) - else: - # For supported models, defaults already have choices - if model in get_args(SupportedModels): - if choices is not None: - _logger.info( - "Model string is in SupportedModels." - " Ignoring choices arguments." - ) - # For custom models, try to get choices - else: - if choices is not None: - config.update_choices(choices) - elif model in ssms_model_config: - config.update_choices(ssms_model_config[model]["choices"]) - _logger.info( - "choices argument passed as None, " - "but found %s in ssms-simulators. " - "Using choices, from ssm-simulators configs: %s", - model, - ssms_model_config[model]["choices"], - ) + Call this at the very start of a subclass ``__init__`` before any local + variables are assigned, passing ``locals()`` and the ``**kwargs`` dict:: - return config + self._init_args = self._store_init_args(locals(), kwargs) - @classproperty - def supported_models(cls) -> tuple[SupportedModels, ...]: - """Get a tuple of all supported models. + Parameters + ---------- + local_vars + The ``locals()`` snapshot from the subclass ``__init__``. + extra_kwargs + The ``**kwargs`` dict captured by the subclass ``__init__``. Returns ------- - tuple[SupportedModels, ...] - A tuple containing all supported model names. + dict[str, Any] + A mapping of parameter names to their values, suitable for + reconstructing the instance via ``cls(**init_args)``. + + Notes + ----- + The implementation filters out internal names that commonly appear in + ``locals()`` snapshots (for example, ``__class__`` and ``kwargs``) so + that the returned mapping is safe to pass back to the class + constructor during unpickling. 
""" - return get_args(SupportedModels) - - @classmethod - def _store_init_args(cls, *args, **kwargs): - """Store initialization arguments using signature binding.""" - sig = signature(cls.__init__) - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - return {k: v for k, v in bound_args.arguments.items() if k != "self"} + # Exclude internal names that appear in locals() snapshots and are not + # valid constructor parameters when re-instantiating the class. + exclude_keys = {"self", "kwargs", "__class__"} + result = {k: v for k, v in local_vars.items() if k not in exclude_keys} + result.update(extra_kwargs) + return result def find_MAP(self, **kwargs): """Perform Maximum A Posteriori estimation. @@ -1804,6 +1675,15 @@ def __getstate__(self): A dictionary containing the constructor arguments under the key 'constructor_args'. """ + # Provide a clear error when the initialization snapshot is missing or + # empty. This makes the contract explicit and avoids an AttributeError + # that is easy to miss for subclasses that forget to capture init args. 
+ if not hasattr(self, "_init_args") or not self._init_args: + raise RuntimeError( + "Model state missing initialization snapshot; ensure subclasses " + "call _store_init_args(locals(), kwargs) early in __init__" + ) + state = {"constructor_args": self._init_args} return state diff --git a/src/hssm/config.py b/src/hssm/config.py index 31415df9a..4b75ba3fb 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -20,6 +20,13 @@ if TYPE_CHECKING: from pytensor.tensor.random.op import RandomVariable +import logging + +from ssms.config import model_config as ssms_model_config + +_logger = logging.getLogger("hssm") + + # ====== Centralized RLSSM defaults ===== DEFAULT_SSM_OBSERVED_DATA = ["rt", "response"] DEFAULT_RLSSM_OBSERVED_DATA = ["rt", "response"] @@ -68,16 +75,6 @@ class BaseModelConfig(ABC): # Additional data requirements extra_fields: list[str] | None = None - @classmethod - @abstractmethod - def get_config_class(cls) -> type["BaseModelConfig"]: - """Return the config class for this model type. - - This enables polymorphic config resolution without circular imports. - Each subclass returns itself as the config class. - """ - ... - @abstractmethod def validate(self) -> None: """Validate configuration. Must be implemented by subclasses.""" @@ -88,6 +85,16 @@ def get_defaults(self, param: str) -> Any: """Get default values for a parameter. Must be implemented by subclasses.""" ... 
+ @property + def n_params(self) -> int | None: + """Return the number of parameters.""" + return len(self.list_params) if self.list_params else None + + @property + def n_extra_fields(self) -> int | None: + """Return the number of extra fields.""" + return len(self.extra_fields) if self.extra_fields else None + @dataclass class Config(BaseModelConfig): @@ -102,11 +109,6 @@ def __post_init__(self): if self.loglik_kind is None: raise ValueError("loglik_kind is required for Config") - @classmethod - def get_config_class(cls) -> type["Config"]: - """Return Config as the config class for HSSM models.""" - return Config - @classmethod def from_defaults( cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None @@ -215,7 +217,7 @@ def update_choices(self, choices: tuple[int, ...] | None) -> None: Parameters ---------- - choices : tuple[int, ...] + choices : tuple[int, ...] | None A tuple of choices. """ if choices is None: @@ -275,6 +277,52 @@ def get_defaults( """ return self.default_priors.get(param), self.bounds.get(param) + @classmethod + def _build_model_config( + cls, + model: SupportedModels | str, + loglik_kind: LoglikKind | None, + model_config: ModelConfig | dict | None, + choices: list[int] | tuple[int, ...] | None, + loglik: Any = None, + ) -> Config: + """Build and return a validated Config for standard HSSM models. + + Resolves defaults, normalizes dict/ModelConfig overrides, applies + choices and loglik precedence rules, then validates before returning. + """ + config = cls.from_defaults(model, loglik_kind) + + if model_config is not None: + final_config = _normalize_model_config_with_choices(model_config, choices) + config.update_config(final_config) + + # NOTE(review): no `else` here - this also runs when model_config is given. + # If caller passed a SupportedModels string, ignore explicit `choices`. + if model in get_args(SupportedModels) and choices is not None: + _logger.info( + "Model string is in SupportedModels. Ignoring choices arguments."
+ ) + + # If model is not a supported built-in, prefer explicit choices or + # fall back to ssms-simulators lookup when available. + if model not in get_args(SupportedModels): + if choices is not None: + config.update_choices(choices) + elif model in ssms_model_config: + config.update_choices(ssms_model_config[model]["choices"]) + _logger.info( + "choices argument passed as None, " + "but found %s in ssms-simulators. " + "Using choices, from ssm-simulators configs: %s", + model, + ssms_model_config[model]["choices"], + ) + + config.update_loglik(loglik) + config.validate() + return config + @dataclass class RLSSMConfig(BaseModelConfig): @@ -305,11 +353,6 @@ def __post_init__(self): if self.loglik_kind is None: self.loglik_kind = "approx_differentiable" - @classmethod - def get_config_class(cls) -> type["RLSSMConfig"]: - """Return RLSSMConfig as the config class for RLSSM models.""" - return RLSSMConfig - @classmethod def from_defaults( cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None @@ -317,18 +360,8 @@ def from_defaults( """Return the shared Config defaults (delegated to :class:`Config`).""" return Config.from_defaults(model_name, loglik_kind) - @property - def n_params(self) -> int | None: - """Return the number of parameters.""" - return len(self.list_params) if self.list_params else None - - @property - def n_extra_fields(self) -> int | None: - """Return the number of extra fields.""" - return len(self.extra_fields) if self.extra_fields else None - @classmethod - def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": + def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> RLSSMConfig: """ Create RLSSMConfig from a configuration dictionary. @@ -509,7 +542,7 @@ def to_config(self) -> Config: loglik=self.loglik, ) - def to_model_config(self) -> "ModelConfig": + def to_model_config(self) -> ModelConfig: """Build a :class:`ModelConfig` from this :class:`RLSSMConfig`. 
def _normalize_model_config_with_choices(
    model_config: ModelConfig | dict[str, Any],
    choices: list[int] | tuple[int, ...] | None,
) -> ModelConfig:
    """Normalize a user-supplied model_config and apply choices.

    Returns a fresh :class:`ModelConfig` instance and does not mutate the
    caller's objects.

    Parameters
    ----------
    model_config
        Either a :class:`ModelConfig` instance or a plain mapping whose keys
        are :class:`ModelConfig` field names.
    choices
        Choices passed directly by the caller, or ``None`` when the caller
        did not supply any.

    Returns
    -------
    ModelConfig
        A new instance built from the non-``None`` entries of
        ``model_config``. Its ``choices`` (when set) is always a tuple.

    Notes
    -----
    Precedence for ``choices``:

    1. A non-``None`` ``choices`` entry inside ``model_config`` always wins.
       If the ``choices`` argument was *also* given, an info-level log entry
       is emitted to tell the user which one was used.
    2. Otherwise the ``choices`` argument is applied, if it is not ``None``.
    3. Otherwise the result carries no ``choices``.
    """
    # Local import: the public dataclasses API is used instead of the private
    # ``__dataclass_fields__`` mapping, without touching file-level imports.
    from dataclasses import fields

    # Work on a private dict so neither the caller's ModelConfig nor the
    # caller's dict is ever mutated.
    if isinstance(model_config, ModelConfig):
        mc: dict[str, Any] = {
            f.name: getattr(model_config, f.name) for f in fields(model_config)
        }
    else:
        mc = dict(model_config)

    embedded_choices = mc.get("choices")
    if embedded_choices is not None:
        # Coerce to a tuple for immutability; the embedded value has priority.
        mc["choices"] = tuple(embedded_choices)
        if choices is not None:
            _logger.info(
                "choices list provided in both model_config and "
                "as an argument directly. Using the one provided in "
                "model_config. We recommend providing choices in model_config."
            )
    elif choices is not None:
        mc["choices"] = tuple(choices)

    # Drop None entries so ModelConfig's own field defaults apply to them.
    return ModelConfig(**{k: v for k, v in mc.items() if v is not None})
self.fget(owner) -class HSSM(HSSMBase, Config): +class HSSM(HSSMBase): """The basic Hierarchical Sequential Sampling Model (HSSM) class. Parameters @@ -101,9 +104,12 @@ class HSSM(HSSMBase, Config): model. If left unspecified, defaults will be used for all parameter specifications. Defaults to None. model_config : optional - A dictionary containing the model configuration information. If None is - provided, defaults will be used if there are any. Defaults to None. - Fields for this `dict` are usually: + A :class:`~hssm.config.BaseModelConfig` / :class:`~hssm.config.Config` + instance or a ``dict`` with model configuration information. The + constructor accepts a typed ``ModelConfig`` or a plain ``dict``; when a + ``dict`` is provided the library will build a typed :class:`Config` + via the factory function. If ``None`` is provided, defaults will be + used where available. Fields for this config are usually: - `"list_params"`: a list of parameters indicating the parameters of the model. The order in which the parameters are specified in this list is important. @@ -248,6 +254,56 @@ class HSSM(HSSMBase, Config): The jitter value for the initial values. 
    def __init__(
        self,
        data: pd.DataFrame,
        model: SupportedModels | str = "ddm",
        choices: list[int] | None = None,
        include: list[dict[str, Any] | Any] | None = None,
        model_config: ModelConfig | dict | None = None,
        loglik: (
            str | PathLike | Callable | pm.Distribution | type[pm.Distribution] | None
        ) = None,
        loglik_kind: LoglikKind | None = None,
        p_outlier: float | dict | bmb.Prior | None = 0.05,
        # NOTE(review): this default is a single bmb.Prior object created once,
        # when the function is defined, and shared by every call that omits
        # `lapse` — confirm that nothing downstream mutates it in place.
        lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0),
        global_formula: str | None = None,
        link_settings: Literal["log_logit"] | None = None,
        prior_settings: Literal["safe"] | None = "safe",
        extra_namespace: dict[str, Any] | None = None,
        missing_data: bool | float = False,
        deadline: bool | str = False,
        loglik_missing_data: (str | PathLike | Callable | None) = None,
        process_initvals: bool = True,
        initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"],
        **kwargs: Any,
    ) -> None:
        """Record the constructor arguments, build a typed config, and initialise the parent.

        Parameter semantics are described in the class docstring. ``model``,
        ``loglik_kind``, ``model_config``, ``choices`` and ``loglik`` are
        consumed here by ``Config._build_model_config``; the object it returns
        is handed to the parent initialiser as ``model_config`` together with
        every remaining argument and ``**kwargs``.
        """
        # ===== save/load serialisation =====
        # Keep this as the first statement: locals() is captured before any
        # further local name is bound in this frame.
        # NOTE(review): how the snapshot is filtered is decided by
        # _store_init_args, which is not visible here — confirm before
        # introducing locals above this line.
        self._init_args = self._store_init_args(locals(), kwargs)

        # Build typed Config via factory
        config = Config._build_model_config(
            model, loglik_kind, model_config, choices, loglik
        )

        # The parent receives the already-built config instead of the raw
        # model / loglik / loglik_kind / choices arguments.
        super().__init__(
            data=data,
            model_config=config,
            include=include,
            p_outlier=p_outlier,
            lapse=lapse,
            global_formula=global_formula,
            link_settings=link_settings,
            prior_settings=prior_settings,
            extra_namespace=extra_namespace,
            missing_data=missing_data,
            deadline=deadline,
            loglik_missing_data=loglik_missing_data,
            process_initvals=process_initvals,
            initval_jitter=initval_jitter,
            **kwargs,
        )
attributes directly + _list_params = self.model_config.list_params + assert _list_params is not None, "list_params should be set" # for type checker + rv_name = getattr(self.model_config, "rv", None) or self.model_config.model_name + return make_distribution( - rv=self.model_config.rv or self.model_name, + rv=rv_name, loglik=self.loglik, - list_params=self.list_params, + list_params=_list_params, bounds=self.bounds, lapse=self.lapse, extra_fields=( diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 339dd870c..c76a0a811 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -25,7 +25,8 @@ if TYPE_CHECKING: from pytensor.graph import Op -from hssm.config import RLSSMConfig + +from hssm.config import Config, RLSSMConfig from hssm.defaults import ( INITVAL_JITTER_SETTINGS, ) @@ -36,7 +37,7 @@ from ..base import HSSMBase -class RLSSM(HSSMBase, RLSSMConfig): +class RLSSM(HSSMBase): """Reinforcement Learning Sequential Sampling Model. Combines a reinforcement learning (RL) process with a sequential sampling @@ -46,7 +47,7 @@ class RLSSM(HSSMBase, RLSSMConfig): The likelihood is built via :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op` from the annotated - SSM function stored in *rlssm_config.ssm_logp_func*. This produces a + SSM function stored in *model_config.ssm_logp_func*. This produces a differentiable pytensor ``Op`` that is passed directly to :func:`~hssm.distribution_utils.make_distribution`, superseding the ``loglik`` / ``loglik_kind`` dispatching used by :class:`~hssm.hssm.HSSM`. @@ -55,12 +56,12 @@ class RLSSM(HSSMBase, RLSSMConfig): ---------- data : pd.DataFrame Trial-level data. Must contain at least the response columns - specified in *rlssm_config* (typically ``"rt"`` and ``"response"``), + specified in *model_config* (typically ``"rt"`` and ``"response"``), a participant identifier column (default ``"participant_id"``), and - any extra fields listed in *rlssm_config.extra_fields*. 
+ any extra fields listed in *model_config.extra_fields*. The data **must** form a balanced panel: every participant must have the same number of trials. - rlssm_config : RLSSMConfig + model_config : RLSSMConfig Full configuration for the RLSSM model. Must have ``ssm_logp_func`` set to the annotated JAX SSM log-likelihood function. participant_col : str, optional @@ -96,18 +97,18 @@ class RLSSM(HSSMBase, RLSSMConfig): Attributes ---------- - _rlssm_config : RLSSMConfig + config : RLSSMConfig The RLSSM configuration object. - _n_participants : int + n_participants : int Number of participants inferred from *data*. - _n_trials : int + n_trials : int Number of trials per participant inferred from *data*. """ def __init__( self, data: pd.DataFrame, - rlssm_config: RLSSMConfig, + model_config: RLSSMConfig, participant_col: str = "participant_id", include: list[dict[str, Any] | Any] | None = None, p_outlier: float | dict | bmb.Prior | None = 0.05, @@ -122,8 +123,11 @@ def __init__( initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], **kwargs: Any, ) -> None: + # ===== save/load serialisation ===== + self._init_args = self._store_init_args(locals(), kwargs) + # Validate config (ensures ssm_logp_func is present, etc.) - rlssm_config.validate() + model_config.validate() # RLSSM reshapes rows into (n_participants, n_trials, ...) by position, # so _rearrange_data (which moves missing/deadline rows to the front) @@ -152,9 +156,9 @@ def __init__( # Store RL-specific state on self BEFORE super().__init__() so that # _make_model_distribution() (called from super) can access them. - self._rlssm_config = rlssm_config - self._n_participants = n_participants - self._n_trials = n_trials + self.config = model_config + self.n_participants = n_participants + self.n_trials = n_trials # Build the differentiable pytensor Op from the annotated SSM function. 
# This Op supersedes the loglik/loglik_kind workflow: it is passed as @@ -166,28 +170,23 @@ def __init__( # "p_outlier" to self.list_params, and that mutation must NOT be visible # to the Op's _validate_args_length check at sampling time. loglik_op = make_rl_logp_op( - ssm_logp_func=rlssm_config.ssm_logp_func, + ssm_logp_func=model_config.ssm_logp_func, n_participants=n_participants, n_trials=n_trials, - data_cols=list(rlssm_config.response), # type: ignore[arg-type] - list_params=list(rlssm_config.list_params), # type: ignore[arg-type] - extra_fields=list(rlssm_config.extra_fields or []), + data_cols=list(model_config.response), # type: ignore[arg-type] + list_params=list(model_config.list_params), # type: ignore[arg-type] + extra_fields=list(model_config.extra_fields or []), ) - # Delegate ModelConfig construction to RLSSMConfig, which already owns - # all the required fields (response, list_params, choices, bounds, …). - mc = rlssm_config.to_model_config() + # Build a typed Config instance via RLSSMConfig's own factory method. + # The differentiable Op is passed so Config.validate() is satisfied; + # loglik_kind="approx_differentiable" reflects that the Op has gradients. + config = model_config._build_model_config(loglik_op) super().__init__( data=data, - model=rlssm_config.model_name, + model_config=config, include=include, - model_config=mc, - # Pass the Op as loglik so Config.validate() is satisfied. - # loglik_kind="approx_differentiable" reflects that the Op is - # differentiable (gradients flow through its VJP). - loglik=loglik_op, - loglik_kind="approx_differentiable", p_outlier=p_outlier, lapse=lapse, link_settings=link_settings, @@ -208,7 +207,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: through :func:`~hssm.distribution_utils.make_likelihood_callable`. Instead it uses ``self.loglik`` directly — the differentiable pytensor ``Op`` built in :meth:`__init__` from - ``self._rlssm_config.ssm_logp_func``. + ``self.config.ssm_logp_func``. 
The Op already handles: - The RL learning rule (computing trial-wise intermediate parameters). @@ -219,27 +218,37 @@ def _make_model_distribution(self) -> type[pm.Distribution]: RLSSM and ``missing_data`` / ``deadline`` are rejected in ``__init__`` before this method is ever reached. """ - # Build params_is_trialwise in the same order as self.list_params so the - # length always matches the list_params= argument passed to make_distribution. - # p_outlier is a scalar mixture weight (not trialwise); every other RLSSM - # parameter is trialwise (the Op receives one value per trial). - assert self.list_params is not None, "list_params should be set by HSSMBase" - params_is_trialwise = [name != "p_outlier" for name in self.list_params] + list_params = self.model_config.list_params + assert list_params is not None, "model_config.list_params must be set" + assert isinstance(list_params, list), ( + "model_config.list_params must be a list" + ) # for type checker + # p_outlier is a scalar mixture weight (not trialwise); every other + # RLSSM parameter is trialwise (the Op receives one value per trial). + params_is_trialwise = [name != "p_outlier" for name in list_params] + + extra_fields = self.model_config.extra_fields or [] extra_fields_data = ( None - if not self.extra_fields - else [self.data[field].to_numpy(copy=True) for field in self.extra_fields] + if not extra_fields + else [self.data[field].to_numpy(copy=True) for field in extra_fields] ) - # self.loglik was set to the pytensor Op built in __init__; cast to - # narrow the inherited union type so make_distribution's type-checker - # accepts it without a runtime penalty. - loglik_op = cast("Callable[..., Any] | Op", self.loglik) + # The differentiable pytensor Op was stored on the validated model_config + # during __init__ as its `loglik`; ensure it's present and cast for typing. 
class TestConfigBuildModelConfigExtraLogic:
    """Cover the choices-resolution branches of ``Config._build_model_config``."""

    def test_build_model_config_dict_with_choices_conflict(self, caplog):
        """A dict config that embeds choices wins over the argument and logs it."""
        # 'ddm' is a built-in model; override it with a minimal dict that
        # already carries its own choices.
        dict_config = {
            "response": ("rt", "response"),
            "list_params": ["v", "a"],
            "choices": (0, 1),
        }

        # Pass a *different* choices argument so both sources are present.
        with caplog.at_level(logging.INFO):
            built = Config._build_model_config("ddm", None, dict_config, choices=[1, 0])

        assert isinstance(built, Config)
        assert "choices list provided in both model_config" in caplog.text

    def test_build_model_config_modelconfig_adds_choices(self):
        """A ModelConfig without choices picks up the choices argument."""
        bare_config = ModelConfig(
            response=("rt", "response"), list_params=["v"], choices=None
        )

        built = Config._build_model_config("ddm", None, bare_config, choices=(0, 1))

        assert built.choices == (0, 1)

    def test_build_model_config_uses_ssms_model_config(self, monkeypatch):
        """An unknown model with no choices falls back to ``ssms_model_config``.

        The model name is not among the built-in SupportedModels and no
        choices are passed, so the choices must come from the external
        registry. ``monkeypatch`` restores the registry after the test.
        """
        import hssm.config as cfgmod

        fake_model = "external_ssm"
        fake_choices = (2, 3)

        # Pretend an external package registered defaults for this model.
        monkeypatch.setitem(
            cfgmod.ssms_model_config, fake_model, {"choices": fake_choices}
        )

        # Minimal ModelConfig plus a dummy loglik: the loglik is only there so
        # that Config.validate() passes; the branch under test is the
        # registry lookup for choices.
        bare_config = ModelConfig(
            response=("rt", "response"), list_params=["v"], choices=None
        )
        built = Config._build_model_config(
            fake_model,
            "analytical",
            bare_config,
            choices=None,
            loglik=(lambda *a, **k: None),
        )

        assert built.choices == fake_choices
model_config=rlssm_config) n_participants = rldm_data["participant_id"].nunique() n_trials = len(rldm_data) // n_participants - assert model._n_participants == n_participants - assert model._n_trials == n_trials + assert model.n_participants == n_participants + assert model.n_trials == n_trials -def test_rlssm_params_keys(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_params_keys(rldm_data, rlssm_config) -> None: """model.params should contain exactly list_params + p_outlier.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) expected = set(rlssm_config.list_params) | {"p_outlier"} assert set(model.params.keys()) == expected -def test_rlssm_unbalanced_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_unbalanced_raises(rldm_data, rlssm_config) -> None: """Dropping one row should make the panel unbalanced → ValueError.""" unbalanced = rldm_data.iloc[:-1].copy() with pytest.raises(ValueError, match="balanced panels"): - RLSSM(data=unbalanced, rlssm_config=rlssm_config) + RLSSM(data=unbalanced, model_config=rlssm_config) -def test_rlssm_nan_participant_id_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_nan_participant_id_raises(rldm_data, rlssm_config) -> None: """NaN in participant_id column should raise ValueError before groupby silently drops rows.""" nan_data = rldm_data.copy() nan_data.loc[nan_data.index[0], "participant_id"] = float("nan") with pytest.raises(ValueError, match="NaN"): - RLSSM(data=nan_data, rlssm_config=rlssm_config) + RLSSM(data=nan_data, model_config=rlssm_config) -def test_rlssm_missing_ssm_logp_func_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_missing_ssm_logp_func_raises(rldm_data, rlssm_config) -> None: """RLSSMConfig without ssm_logp_func should raise ValueError on init.""" bad_config = RLSSMConfig( 
model_name="rldm_bad", @@ -168,12 +162,10 @@ def test_rlssm_missing_ssm_logp_func_raises( # ssm_logp_func intentionally omitted → defaults to None ) with pytest.raises(ValueError, match="ssm_logp_func"): - RLSSM(data=rldm_data, rlssm_config=bad_config) + RLSSM(data=rldm_data, model_config=bad_config) -def test_rlssm_unannotated_ssm_logp_func_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_unannotated_ssm_logp_func_raises(rldm_data, rlssm_config) -> None: """A plain callable without @annotate_function attrs should raise ValueError.""" bad_config = RLSSMConfig( model_name="rldm_bad", @@ -191,23 +183,19 @@ def test_rlssm_unannotated_ssm_logp_func_raises( ssm_logp_func=lambda x: x, # callable but no .inputs/.outputs/.computed ) with pytest.raises(ValueError, match="annotate_function"): - RLSSM(data=rldm_data, rlssm_config=bad_config) + RLSSM(data=rldm_data, model_config=bad_config) -def test_rlssm_missing_data_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_missing_data_raises(rldm_data, rlssm_config) -> None: """Passing missing_data!=False should raise ValueError with 'missing_data' in msg.""" with pytest.raises(ValueError, match="missing_data"): - RLSSM(data=rldm_data, rlssm_config=rlssm_config, missing_data=True) + RLSSM(data=rldm_data, model_config=rlssm_config, missing_data=True) -def test_rlssm_deadline_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_deadline_raises(rldm_data, rlssm_config) -> None: """Passing deadline!=False should raise ValueError with 'deadline' in msg.""" with pytest.raises(ValueError, match="deadline"): - RLSSM(data=rldm_data, rlssm_config=rlssm_config, deadline=True) + RLSSM(data=rldm_data, model_config=rlssm_config, deadline=True) # --------------------------------------------------------------------------- @@ -215,43 +203,43 @@ def test_rlssm_deadline_raises( # 
--------------------------------------------------------------------------- -def test_rlssm_params_is_trialwise_aligned( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_params_is_trialwise_aligned(rldm_data, rlssm_config) -> None: """params_is_trialwise must align with list_params (same length, p_outlier=False).""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) - assert model.list_params is not None - params_is_trialwise = [name != "p_outlier" for name in model.list_params] - assert len(params_is_trialwise) == len(model.list_params) - for name, is_tw in zip(model.list_params, params_is_trialwise): + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.model_config.list_params is not None + params_is_trialwise = [ + name != "p_outlier" for name in model.model_config.list_params + ] + assert len(params_is_trialwise) == len(model.model_config.list_params) + for name, is_tw in zip(model.model_config.list_params, params_is_trialwise): if name == "p_outlier": assert not is_tw, "p_outlier must be non-trialwise" else: assert is_tw, f"{name} must be trialwise" -def test_rlssm_get_prefix(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_get_prefix(rldm_data, rlssm_config) -> None: """_get_prefix must use token-based matching, not substring search. 
- 'rl_alpha_Intercept' → 'rl_alpha' (underscore-containing RL param) - 'p_outlier_log__' → 'p_outlier' (lapse param via token loop, not substring) - 'a_Intercept' → 'a' (single-token standard param) """ - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) assert model._get_prefix("rl_alpha_Intercept") == "rl_alpha" assert model._get_prefix("p_outlier_log__") == "p_outlier" assert model._get_prefix("a_Intercept") == "a" -def test_rlssm_no_lapse(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_no_lapse(rldm_data, rlssm_config) -> None: """Setting p_outlier=None should remove p_outlier from params.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config, p_outlier=None) + model = RLSSM(data=rldm_data, model_config=rlssm_config, p_outlier=None) assert "p_outlier" not in model.params -def test_rlssm_model_built(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_model_built(rldm_data, rlssm_config) -> None: """The bambi model should be built and the computed param 'v' absent from params.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) assert model.model is not None # rl_alpha is a free (sampled) parameter assert "rl_alpha" in model.params @@ -259,9 +247,7 @@ def test_rlssm_model_built(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) - assert "v" not in model.params -def test_rlssm_extra_fields_are_copies( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_extra_fields_are_copies(rldm_data, rlssm_config) -> None: """extra_fields passed to make_distribution must be independent numpy copies. 
to_numpy(copy=True) should return a new buffer; if it returned a view, @@ -271,7 +257,7 @@ def test_rlssm_extra_fields_are_copies( from hssm.distribution_utils import make_distribution as real_make_distribution - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) captured: dict = {} def capturing_make_distribution(*args, **kwargs): @@ -292,9 +278,9 @@ def capturing_make_distribution(*args, **kwargs): ) -def test_rlssm_pymc_model(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_pymc_model(rldm_data, rlssm_config) -> None: """pymc_model should be accessible after model construction.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) assert model.pymc_model is not None @@ -304,8 +290,10 @@ def test_rlssm_pymc_model(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> @pytest.mark.slow -def test_rlssm_sample_smoke(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_sample_smoke(rldm_data, rlssm_config) -> None: """Minimal sampling run should return an InferenceData object.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) - trace = model.sample(draws=2, tune=2, chains=1, cores=1) + model = RLSSM(data=rldm_data, model_config=rlssm_config) + trace = model.sample( + draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9 + ) assert trace is not None diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 143ff1176..46859da47 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -313,7 +313,7 @@ def test_to_config_cases( expected_default_priors, raises, ): - rlssm_config = RLSSMConfig( + model_config = RLSSMConfig( model_name="test_model", list_params=list_params, params_default=params_default, @@ -327,14 +327,14 @@ def test_to_config_cases( ) if raises: with pytest.raises(raises): - rlssm_config.to_config() + 
model_config.to_config() else: - config = rlssm_config.to_config() + config = model_config.to_config() assert config.backend == expected_backend assert config.default_priors == expected_default_priors def test_to_config(self): - rlssm_config = RLSSMConfig( + model_config = RLSSMConfig( model_name="rlwm", description="RLWM model", list_params=["alpha", "beta", "v", "a"], @@ -354,7 +354,7 @@ def test_to_config(self): learning_process_loglik_kind="blackbox", learning_process={}, ) - config = rlssm_config.to_config() + config = model_config.to_config() assert isinstance(config, Config) assert config.model_name == "rlwm" assert config.description == "RLWM model" @@ -378,7 +378,7 @@ def test_to_config(self): } def test_to_config_defaults_backend(self): - rlssm_config = RLSSMConfig( + model_config = RLSSMConfig( model_name="test_model", list_params=["alpha"], params_default=[0.5], @@ -389,11 +389,11 @@ def test_to_config_defaults_backend(self): learning_process_loglik_kind="blackbox", learning_process={}, ) - config = rlssm_config.to_config() + config = model_config.to_config() assert config.backend == "jax" def test_to_config_no_defaults(self): - rlssm_config = RLSSMConfig( + model_config = RLSSMConfig( model_name="test_model", list_params=["alpha", "beta"], params_default=[], @@ -404,11 +404,11 @@ def test_to_config_no_defaults(self): learning_process_loglik_kind="blackbox", learning_process={}, ) - config = rlssm_config.to_config() + config = model_config.to_config() assert config.default_priors == {} def test_to_config_mismatched_defaults_length(self): - rlssm_config = RLSSMConfig( + model_config = RLSSMConfig( model_name="test_model", list_params=["alpha", "beta", "gamma"], params_default=[0.5, 0.3], @@ -423,7 +423,7 @@ def test_to_config_mismatched_defaults_length(self): ValueError, match=r"params_default length \(2\) doesn't match list_params length \(3\)", ): - rlssm_config.to_config() + model_config.to_config() class TestRLSSMConfigLearningProcess: diff --git 
a/tests/test_save_load.py b/tests/test_save_load.py index 489dca123..614427f70 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -13,7 +13,9 @@ def compare_hssm_class_attributes(model_a, model_b): b = np.array([type(v) for k, v in model_b._init_args.items()]) assert (a == b).all(), "Init arg types not the same" assert (model_a.data).equals(model_b.data), "Data not the same" - assert model_a.model_name == model_b.model_name, "Model name not the same" + assert model_a.model_config.model_name == model_b.model_config.model_name, ( + "Model name not the same" + ) assert model_a.pymc_model._repr_latex_() == model_b.pymc_model._repr_latex_(), ( "Latex representation of model not the same" )