From 20ddc4c41712bb8faa73b3658e702506a7345650 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 12:33:37 -0500 Subject: [PATCH 01/32] Add ssm_logp_func to RLSSMConfig and update validation tests --- src/hssm/config.py | 16 ++++++++++++++++ tests/test_rlssm_config.py | 6 ++++++ 2 files changed, 22 insertions(+) diff --git a/src/hssm/config.py b/src/hssm/config.py index 1ae0ea2a0..64e33e932 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -266,6 +266,12 @@ class RLSSMConfig(BaseModelConfig): This configuration class is designed for models that combine reinforcement learning processes with sequential sampling decision models (RLSSM). + + The ``ssm_logp_func`` field holds the fully annotated JAX SSM log-likelihood + function (an :class:`AnnotatedFunction`) that is passed directly to + ``make_rl_logp_op``. It supersedes the ``loglik`` / ``loglik_kind`` workflow + used by :class:`HSSM`: the Op is built from ``ssm_logp_func`` and therefore + no ``loglik`` callable needs to be provided. """ decision_process_loglik_kind: str = field(kw_only=True) @@ -273,6 +279,10 @@ class RLSSMConfig(BaseModelConfig): params_default: list[float] = field(kw_only=True) decision_process: str | ModelConfig = field(kw_only=True) learning_process: dict[str, Any] = field(kw_only=True) + # The fully annotated SSM log-likelihood function used by make_rl_logp_op. + # Type is Any to avoid a hard dependency on the AnnotatedFunction Protocol at + # import time; validated at runtime in validate(). 
+ ssm_logp_func: Any = field(default=None, kw_only=True) def __post_init__(self): """Set default loglik_kind for RLSSM models if not provided.""" @@ -332,6 +342,7 @@ def from_rlssm_dict(cls, model_name: str, config_dict: dict[str, Any]): params_default=config_dict["params_default"], decision_process=config_dict["decision_process"], learning_process=config_dict["learning_process"], + ssm_logp_func=config_dict.get("ssm_logp_func"), bounds=config_dict.get("bounds", {}), response=config_dict["response"], choices=config_dict["choices"], @@ -355,6 +366,11 @@ def validate(self) -> None: raise ValueError("Please provide `choices` in the configuration.") if self.decision_process is None: raise ValueError("Please specify a `decision_process`.") + if self.ssm_logp_func is None: + raise ValueError( + "Please provide `ssm_logp_func`: the fully annotated JAX SSM " + "log-likelihood function required by `make_rl_logp_op`." + ) # Validate parameter defaults consistency if self.params_default and self.list_params: diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 046ee6e4e..9ae8257c5 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -51,6 +51,9 @@ def create_config_dict( # region fixtures and helpers @pytest.fixture def valid_rlssmconfig_kwargs(): + def _dummy_ssm_logp_func(x): + return x + return dict( model_name="test_model", list_params=["alpha", "beta"], @@ -62,6 +65,7 @@ def valid_rlssmconfig_kwargs(): decision_process_loglik_kind="analytical", learning_process_loglik_kind="blackbox", learning_process={}, + ssm_logp_func=_dummy_ssm_logp_func, ) @@ -166,6 +170,7 @@ class TestRLSSMConfigValidation: ("list_params", None, "Please provide `list_params"), ("choices", None, "Please provide `choices"), ("decision_process", None, "Please specify a `decision_process"), + ("ssm_logp_func", None, "Please provide `ssm_logp_func"), ], ) def test_validate_missing_fields( @@ -212,6 +217,7 @@ def test_validate_params_default_mismatch(self): 
decision_process_loglik_kind="analytical", learning_process_loglik_kind="blackbox", learning_process={}, + ssm_logp_func=lambda x: x, ) with pytest.raises( ValueError, From d97dcee352d7b894e420360e36abb4c27e1466a2 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 12:44:26 -0500 Subject: [PATCH 02/32] Add RLSSM model and utilities for reinforcement learning integration --- src/hssm/__init__.py | 5 +- src/hssm/rl/__init__.py | 18 +++ src/hssm/rl/rlssm.py | 266 ++++++++++++++++++++++++++++++++++++++++ src/hssm/rl/utils.py | 49 ++++++++ 4 files changed, 337 insertions(+), 1 deletion(-) create mode 100644 src/hssm/rl/__init__.py create mode 100644 src/hssm/rl/rlssm.py create mode 100644 src/hssm/rl/utils.py diff --git a/src/hssm/__init__.py b/src/hssm/__init__.py index 60dd71020..72aa14ba3 100644 --- a/src/hssm/__init__.py +++ b/src/hssm/__init__.py @@ -11,7 +11,7 @@ import logging import sys -from .config import ModelConfig +from .config import ModelConfig, RLSSMConfig from .datasets import load_data from .defaults import show_defaults from .hssm import HSSM @@ -19,6 +19,7 @@ from .param import UserParam as Param from .prior import Prior from .register import register_model +from .rl import RLSSM from .simulator import simulate_data from .utils import check_data_for_rl, set_floatX @@ -31,6 +32,8 @@ __all__ = [ "HSSM", + "RLSSM", + "RLSSMConfig", "Link", "load_data", "ModelConfig", diff --git a/src/hssm/rl/__init__.py b/src/hssm/rl/__init__.py new file mode 100644 index 000000000..09eb646ae --- /dev/null +++ b/src/hssm/rl/__init__.py @@ -0,0 +1,18 @@ +"""Reinforcement learning extensions for HSSM. + +This sub-package provides: + +- :class:`~hssm.rl.rlssm.RLSSM` — the RL + SSM model class. +- :func:`~hssm.rl.utils.validate_balanced_panel` — panel-balance utility. +- :mod:`hssm.rl.likelihoods` — log-likelihood builders + (:func:`~hssm.rl.likelihoods.builder.make_rl_logp_func`, + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`). 
+""" + +from .rlssm import RLSSM +from .utils import validate_balanced_panel + +__all__ = [ + "RLSSM", + "validate_balanced_panel", +] diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py new file mode 100644 index 000000000..dce195f3b --- /dev/null +++ b/src/hssm/rl/rlssm.py @@ -0,0 +1,266 @@ +"""RLSSM: Reinforcement Learning Sequential Sampling Model. + +This module defines the :class:`RLSSM` class, a subclass of :class:`HSSMBase` +for models that couple a reinforcement learning (RL) learning process with a +sequential sampling decision model (SSM). + +The key difference from :class:`HSSM` is the likelihood: + - ``HSSM`` wraps an analytical / ONNX / blackbox callable via + :func:`~hssm.distribution_utils.make_likelihood_callable`. + - ``RLSSM`` builds a differentiable pytensor ``Op`` directly from an + :class:`~hssm.rl.likelihoods.builder.AnnotatedFunction` via + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`, which internally + handles the RL learning rule and per-participant trial structure. + This Op is then passed straight to + :func:`~hssm.distribution_utils.make_distribution`, bypassing the + standard ``loglik`` / ``loglik_kind`` wrapping pipeline. +""" + +import logging +from copy import deepcopy +from typing import Any, Callable, Literal + +import bambi as bmb +import pandas as pd +import pymc as pm + +from hssm.config import ModelConfig, RLSSMConfig +from hssm.defaults import ( + INITVAL_JITTER_SETTINGS, + MissingDataNetwork, +) +from hssm.distribution_utils import make_distribution +from hssm.rl.likelihoods.builder import make_rl_logp_op +from hssm.rl.utils import validate_balanced_panel +from hssm.utils import _rearrange_data + +from ..base import HSSMBase + +_logger = logging.getLogger("hssm") + + +class RLSSM(HSSMBase): + """Reinforcement Learning Sequential Sampling Model. + + Combines a reinforcement learning (RL) process with a sequential sampling + model (SSM) inside a single differentiable likelihood. 
The RL component + computes trial-wise intermediate parameters (e.g., drift rates) from the + learning history, which are then fed into the SSM log-likelihood. + + The likelihood is built via + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op` from the annotated + SSM function stored in *rlssm_config.ssm_logp_func*. This produces a + differentiable pytensor ``Op`` that is passed directly to + :func:`~hssm.distribution_utils.make_distribution`, superseding the + ``loglik`` / ``loglik_kind`` dispatching used by :class:`~hssm.hssm.HSSM`. + + Parameters + ---------- + data : pd.DataFrame + Trial-level data. Must contain at least the response columns + specified in *rlssm_config* (typically ``"rt"`` and ``"response"``), + a participant identifier column (default ``"participant_id"``), and + any extra fields listed in *rlssm_config.extra_fields*. + The data **must** form a balanced panel: every participant must have + the same number of trials. + rlssm_config : RLSSMConfig + Full configuration for the RLSSM model. Must have ``ssm_logp_func`` + set to the annotated JAX SSM log-likelihood function. + participant_col : str, optional + Name of the column that uniquely identifies participants. + Used to infer ``n_participants`` and ``n_trials`` from *data*. + Defaults to ``"participant_id"``. + include : list, optional + Parameter specifications forwarded to :class:`~hssm.base.HSSMBase`. + p_outlier : float | dict | bmb.Prior | None, optional + Lapse probability specification. Defaults to ``0.05``. + lapse : dict | bmb.Prior | None, optional + Lapse distribution. Defaults to ``Uniform(0, 20)``. + link_settings : Literal["log_logit"] | None, optional + Link-function preset. Defaults to ``None``. + prior_settings : Literal["safe"] | None, optional + Prior preset. Defaults to ``"safe"``. + extra_namespace : dict | None, optional + Extra variables for formula evaluation. Defaults to ``None``. 
+ missing_data : bool | float, optional + Whether to handle missing RT data coded as ``-999.0``. + Defaults to ``False``. + deadline : bool | str, optional + Whether to handle deadline data. Defaults to ``False``. + loglik_missing_data : Callable | None, optional + Custom likelihood for missing observations. Defaults to ``None``. + process_initvals : bool, optional + Whether to post-process initial values. Defaults to ``True``. + initval_jitter : float, optional + Jitter magnitude for initial values. + Defaults to :data:`~hssm.defaults.INITVAL_JITTER_SETTINGS` epsilon. + **kwargs + Additional keyword arguments forwarded to :class:`bmb.Model`. + + Attributes + ---------- + _rlssm_config : RLSSMConfig + The RLSSM configuration object. + _n_participants : int + Number of participants inferred from *data*. + _n_trials : int + Number of trials per participant inferred from *data*. + """ + + def __init__( + self, + data: pd.DataFrame, + rlssm_config: RLSSMConfig, + participant_col: str = "participant_id", + include: list[dict[str, Any] | Any] | None = None, + p_outlier: float | dict | bmb.Prior | None = 0.05, + lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), + link_settings: Literal["log_logit"] | None = None, + prior_settings: Literal["safe"] | None = "safe", + extra_namespace: dict[str, Any] | None = None, + missing_data: bool | float = False, + deadline: bool | str = False, + loglik_missing_data: Callable | None = None, + process_initvals: bool = True, + initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], + **kwargs: Any, + ) -> None: + # Validate config (ensures ssm_logp_func is present, etc.) + rlssm_config.validate() + + # Infer panel structure and validate balance BEFORE calling super so any + # error surfaces before the expensive model-build steps. 
+ n_participants, n_trials = validate_balanced_panel(data, participant_col) + + # Store RL-specific state on self BEFORE super().__init__() so that + # _make_model_distribution() (called from super) can access them. + self._rlssm_config = rlssm_config + self._n_participants = n_participants + self._n_trials = n_trials + + # Determine data / param column names for the Op + data_cols: list[str] = ( + list(rlssm_config.response) if rlssm_config.response else ["rt", "response"] + ) + list_params: list[str] = rlssm_config.list_params or [] + extra_fields: list[str] = rlssm_config.extra_fields or [] + + # Build the differentiable pytensor Op from the annotated SSM function. + # This Op supersedes the loglik/loglik_kind workflow: it is passed as + # `loglik` to HSSMBase so Config.validate() is satisfied, and + # _make_model_distribution() uses it directly without any further wrapping. + loglik_op = make_rl_logp_op( + ssm_logp_func=rlssm_config.ssm_logp_func, + n_participants=n_participants, + n_trials=n_trials, + data_cols=data_cols, + list_params=list_params, + extra_fields=extra_fields, + ) + + # Build default_priors from params_default for HSSMBase + default_priors: dict[str, Any] = ( + { + param: default + for param, default in zip( + rlssm_config.list_params, rlssm_config.params_default + ) + } + if rlssm_config.list_params and rlssm_config.params_default + else {} + ) + + # Build a ModelConfig so HSSMBase._build_model_config can apply the + # RLSSM-specific fields (response, list_params, choices, bounds, …). 
+ mc = ModelConfig( + response=(tuple(rlssm_config.response) if rlssm_config.response else None), + list_params=list_params, + choices=(tuple(rlssm_config.choices) if rlssm_config.choices else None), + default_priors=default_priors, + bounds=rlssm_config.bounds or {}, + extra_fields=extra_fields if extra_fields else None, + backend="jax", # RLSSM always uses the JAX backend + ) + + super().__init__( + data=data, + model=rlssm_config.model_name, + choices=list(rlssm_config.choices) if rlssm_config.choices else None, + include=include, + model_config=mc, + # Pass the Op as loglik so Config.validate() is satisfied. + # loglik_kind="approx_differentiable" reflects that the Op is + # differentiable (gradients flow through its VJP). + loglik=loglik_op, + loglik_kind="approx_differentiable", + p_outlier=p_outlier, + lapse=lapse, + link_settings=link_settings, + prior_settings=prior_settings, + extra_namespace=extra_namespace, + missing_data=missing_data, + deadline=deadline, + loglik_missing_data=loglik_missing_data, + process_initvals=process_initvals, + initval_jitter=initval_jitter, + **kwargs, + ) + + def _make_model_distribution(self) -> type[pm.Distribution]: + """Build a pm.Distribution using the pre-built RL log-likelihood Op. + + Unlike :meth:`HSSM._make_model_distribution`, this method does not go + through :func:`~hssm.distribution_utils.make_likelihood_callable`. + Instead it uses ``self.loglik`` directly — the differentiable pytensor + ``Op`` built in :meth:`__init__` from + ``self._rlssm_config.ssm_logp_func``. + + The Op already handles: + - The RL learning rule (computing trial-wise intermediate parameters). + - The per-participant / per-trial data reshaping. + - Gradient computation via its embedded VJP. + + Missing-data network assembly (OPN / CPN) is not yet supported for + RLSSM and logs a warning if requested. + """ + # Warn if a missing-data network was requested; not supported yet. 
+ if self.missing_data_network != MissingDataNetwork.NONE: + _logger.warning( + "Missing-data network assembly (OPN/CPN) is not yet supported " + "for RLSSM. The missing_data_network setting will be ignored." + ) + + if self.missing_data: + _logger.info( + "Re-arranging data to separate missing and observed datapoints. " + "Missing data (rt == %s) will be on top, " + "observed datapoints follow.", + self.missing_data_value, + ) + + # Rearrange data so missing rows come first (no-op when missing_data=False). + self.data = _rearrange_data(self.data) + + # All RLSSM parameters are treated as trialwise: the Op expects arrays of + # length n_total_trials for every parameter, and make_distribution.logp + # broadcasts scalar / (1,)-shaped tensors up to (n_obs,) accordingly. + params_is_trialwise = [ + True for param_name in self.params if param_name != "p_outlier" + ] + + extra_fields_data = ( + None + if not self.extra_fields + else [deepcopy(self.data[field].values) for field in self.extra_fields] + ) + + assert self.list_params is not None, "list_params should be set" + return make_distribution( + rv=self.model_name, + loglik=self.loglik, + list_params=self.list_params, + bounds=self.bounds, + lapse=self.lapse, + extra_fields=extra_fields_data, + params_is_trialwise=params_is_trialwise, + ) diff --git a/src/hssm/rl/utils.py b/src/hssm/rl/utils.py new file mode 100644 index 000000000..485e02ea3 --- /dev/null +++ b/src/hssm/rl/utils.py @@ -0,0 +1,49 @@ +"""Utility functions for reinforcement learning + SSM models.""" + +import pandas as pd + + +def validate_balanced_panel( + data: pd.DataFrame, + participant_col: str = "participant_id", +) -> tuple[int, int]: + """Validate that data forms a balanced panel and return its shape. + + A balanced panel requires every participant to have exactly the same number + of trials (rows in *data*). + + Parameters + ---------- + data : pd.DataFrame + The DataFrame to validate. 
+ participant_col : str, optional + Name of the column that identifies participants. + Defaults to ``"participant_id"``. + + Returns + ------- + tuple[int, int] + ``(n_participants, n_trials)`` where *n_trials* is the number of rows + per participant. + + Raises + ------ + ValueError + If *participant_col* is not present in *data*, or if the panel is + unbalanced (participants have different trial counts). + """ + if participant_col not in data.columns: + raise ValueError( + f"Column '{participant_col}' not found in data. " + "Please provide the correct participant column name via " + "`participant_col`." + ) + + counts = data.groupby(participant_col).size() + if counts.nunique() != 1: + raise ValueError( + "Data must form balanced panels: all participants must have the " + f"same number of trials. Observed trial counts: {dict(counts)}" + ) + + return int(len(counts)), int(counts.iloc[0]) From a6a02385af8e91c977e3f6c772eb26b75f079f90 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 13:39:23 -0500 Subject: [PATCH 03/32] Refactor RLSSM parameter handling and add custom prefix resolution for RL parameters --- src/hssm/rl/rlssm.py | 57 +++++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index dce195f3b..bf7f129c9 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -142,41 +142,41 @@ def __init__( data_cols: list[str] = ( list(rlssm_config.response) if rlssm_config.response else ["rt", "response"] ) - list_params: list[str] = rlssm_config.list_params or [] + list_params: list[str] = ( + list(rlssm_config.list_params) if rlssm_config.list_params else [] + ) extra_fields: list[str] = rlssm_config.extra_fields or [] # Build the differentiable pytensor Op from the annotated SSM function. 
# This Op supersedes the loglik/loglik_kind workflow: it is passed as # `loglik` to HSSMBase so Config.validate() is satisfied, and # _make_model_distribution() uses it directly without any further wrapping. + # + # Pass copies of list_params / extra_fields so the closure inside + # make_rl_logp_func captures its own isolated list objects. HSSMBase will + # later append "p_outlier" to self.list_params (which is the SAME list + # object as `list_params` above), and that mutation must NOT be visible to + # the Op's _validate_args_length check at sampling time. loglik_op = make_rl_logp_op( ssm_logp_func=rlssm_config.ssm_logp_func, n_participants=n_participants, n_trials=n_trials, - data_cols=data_cols, - list_params=list_params, - extra_fields=extra_fields, - ) - - # Build default_priors from params_default for HSSMBase - default_priors: dict[str, Any] = ( - { - param: default - for param, default in zip( - rlssm_config.list_params, rlssm_config.params_default - ) - } - if rlssm_config.list_params and rlssm_config.params_default - else {} + data_cols=list(data_cols), + list_params=list(list_params), + extra_fields=list(extra_fields), ) # Build a ModelConfig so HSSMBase._build_model_config can apply the # RLSSM-specific fields (response, list_params, choices, bounds, …). + # default_priors is left empty so that the prior_settings="safe" + # mechanism in HSSMBase assigns sensible priors from bounds. Using + # params_default (scalar floats) here would fix every parameter as a + # constant, which is incorrect. 
mc = ModelConfig( response=(tuple(rlssm_config.response) if rlssm_config.response else None), list_params=list_params, choices=(tuple(rlssm_config.choices) if rlssm_config.choices else None), - default_priors=default_priors, + default_priors={}, bounds=rlssm_config.bounds or {}, extra_fields=extra_fields if extra_fields else None, backend="jax", # RLSSM always uses the JAX backend @@ -264,3 +264,26 @@ def _make_model_distribution(self) -> type[pm.Distribution]: extra_fields=extra_fields_data, params_is_trialwise=params_is_trialwise, ) + + def _get_prefix(self, name_str: str) -> str: + """Resolve parameter prefix, handling underscore-containing RL param names. + + The base-class implementation splits ``name_str`` on the first ``_`` and + returns that single token as the parameter name. This breaks for RL + parameters whose names contain underscores (e.g. ``rl_alpha``), because + ``"rl_alpha_Intercept".split("_")[0]`` yields ``"rl"``, which is absent + from ``self.params``. + + This override tries successively shorter underscore-joined prefixes until + it finds one present in ``self.params``, falling back to the base-class + behaviour when no match is found. 
+ """ + if "p_outlier" in name_str: + return "p_outlier" + if "_" in name_str: + parts = name_str.split("_") + for i in range(len(parts), 0, -1): + candidate = "_".join(parts[:i]) + if candidate in self.params: + return candidate + return super()._get_prefix(name_str) From d8809772d21b1a3041ef36bc2d0ffba2cf238259 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 13:39:30 -0500 Subject: [PATCH 04/32] Add tests for RLSSM class covering initialization, validation, and model structure --- tests/test_rlssm.py | 186 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 tests/test_rlssm.py diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py new file mode 100644 index 000000000..b189fa1ed --- /dev/null +++ b/tests/test_rlssm.py @@ -0,0 +1,186 @@ +"""Tests for the RLSSM class. + +Mirrors the structure of tests/test_hssm.py, covering initialisation, +config validation, param keys, balanced-panel enforcement, the no-lapse +variant, bambi / PyMC model construction, and a sampling smoke test. +""" + +import jax.numpy as jnp +import numpy as np +import pandas as pd +import pytest + +import hssm +from hssm import RLSSM, RLSSMConfig +from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise +from hssm.utils import annotate_function + +hssm.set_floatX("float32", update_jax=True) + +# --------------------------------------------------------------------------- +# Module-level annotated helpers (shared by all tests) +# --------------------------------------------------------------------------- + +# Annotate the RL learning function: maps +# (rl_alpha, scaler, response, feedback) -> v +_compute_v_annotated = annotate_function( + inputs=["rl_alpha", "scaler", "response", "feedback"], + outputs=["v"], +)(compute_v_subject_wise) + + +# Annotated SSM log-likelihood function (simplified for testing). 
+# It receives a 2-D lan_matrix whose columns correspond to +# [v, a, z, t, theta, rt, response] +# and returns per-trial log-probabilities of shape (n_total_trials,). +@annotate_function( + inputs=["v", "a", "z", "t", "theta", "rt", "response"], + outputs=["logp"], + computed={"v": _compute_v_annotated}, +) +def _dummy_ssm_logp(lan_matrix: jnp.ndarray) -> jnp.ndarray: + """Return per-trial log-probabilities (column-sum); structural tests only.""" + # Return 1D (N,) — PyTensor declares the Op output as pt.vector(), so + # gradients arrive as (N,). A (N,1) return causes a VJP shape mismatch. + return jnp.sum(lan_matrix, axis=1) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def rldm_data() -> pd.DataFrame: + """Load the RLDM fixture dataset (balanced panel).""" + raw = np.load("tests/fixtures/rldm_data.npy", allow_pickle=True).item() + return pd.DataFrame(raw["data"]) + + +@pytest.fixture(scope="module") +def rlssm_config() -> RLSSMConfig: + """Minimal but valid RLSSMConfig for the RLDM fixture dataset.""" + return RLSSMConfig( + model_name="rldm_test", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_loglik_kind="blackbox", + list_params=["rl_alpha", "scaler", "a", "theta", "t", "z"], + params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5], + bounds={ + "rl_alpha": (0.0, 1.0), + "scaler": (0.0, 10.0), + "a": (0.1, 3.0), + "theta": (-0.1, 0.1), + "t": (0.001, 1.0), + "z": (0.1, 0.9), + }, + learning_process={"v": _compute_v_annotated}, + response=["rt", "response"], + choices=[0, 1], + extra_fields=["feedback"], + ssm_logp_func=_dummy_ssm_logp, + ) + + +# --------------------------------------------------------------------------- +# Initialisation & config-validation tests +# 
--------------------------------------------------------------------------- + + +def test_rlssm_init(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """Basic RLSSM initialisation should succeed and return an RLSSM instance.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + assert isinstance(model, RLSSM) + assert model.model_name == "rldm_test" + + +def test_rlssm_panel_attrs(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """_n_participants and _n_trials should match the fixture data structure.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + + n_participants = rldm_data["participant_id"].nunique() + n_trials = len(rldm_data) // n_participants + + assert model._n_participants == n_participants + assert model._n_trials == n_trials + + +def test_rlssm_params_keys(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """model.params should contain exactly list_params + p_outlier.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + expected = set(rlssm_config.list_params) | {"p_outlier"} + assert set(model.params.keys()) == expected + + +def test_rlssm_unbalanced_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """Dropping one row should make the panel unbalanced → ValueError.""" + unbalanced = rldm_data.iloc[:-1].copy() + with pytest.raises(ValueError, match="balanced panels"): + RLSSM(data=unbalanced, rlssm_config=rlssm_config) + + +def test_rlssm_missing_ssm_logp_func_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """RLSSMConfig without ssm_logp_func should raise ValueError on init.""" + bad_config = RLSSMConfig( + model_name="rldm_bad", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_loglik_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + 
learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=list(rlssm_config.extra_fields), + # ssm_logp_func intentionally omitted → defaults to None + ) + with pytest.raises(ValueError, match="ssm_logp_func"): + RLSSM(data=rldm_data, rlssm_config=bad_config) + + +# --------------------------------------------------------------------------- +# Model-structure tests +# --------------------------------------------------------------------------- + + +def test_rlssm_no_lapse(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """Setting p_outlier=None should remove p_outlier from params.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config, p_outlier=None) + assert "p_outlier" not in model.params + + +def test_rlssm_model_built(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """The bambi model should be built and the computed param 'v' absent from params.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + assert model.model is not None + # rl_alpha is a free (sampled) parameter + assert "rl_alpha" in model.params + # v is computed inside the Op; it must NOT appear as a free parameter + assert "v" not in model.params + + +def test_rlssm_pymc_model(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """pymc_model should be accessible after model construction.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + assert model.pymc_model is not None + + +# --------------------------------------------------------------------------- +# Slow sampling smoke test +# --------------------------------------------------------------------------- + + +@pytest.mark.slow +def test_rlssm_sample_smoke(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """Minimal sampling run should return an InferenceData object.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + trace = model.sample(draws=2, tune=2, chains=1, 
cores=1) + assert trace is not None From bef8d6c0f40ac3fbad1fdc1df98d3b1fc5a246f2 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 13:50:35 -0500 Subject: [PATCH 05/32] Refactor loglik handling in RLSSM to improve type safety with casting --- src/hssm/rl/rlssm.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index bf7f129c9..7aea9bef6 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -18,12 +18,15 @@ import logging from copy import deepcopy -from typing import Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, Literal, cast import bambi as bmb import pandas as pd import pymc as pm +if TYPE_CHECKING: + from pytensor.graph import Op + from hssm.config import ModelConfig, RLSSMConfig from hssm.defaults import ( INITVAL_JITTER_SETTINGS, @@ -255,9 +258,13 @@ def _make_model_distribution(self) -> type[pm.Distribution]: ) assert self.list_params is not None, "list_params should be set" + # self.loglik was set to the pytensor Op built in __init__; cast to + # narrow the inherited union type so make_distribution's type-checker + # accepts it without a runtime penalty. + loglik_op = cast("Callable[..., Any] | Op", self.loglik) return make_distribution( rv=self.model_name, - loglik=self.loglik, + loglik=loglik_op, list_params=self.list_params, bounds=self.bounds, lapse=self.lapse, From 3981ef660199d226455b6a9b000f5e6455e65cd9 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 14:29:58 -0500 Subject: [PATCH 06/32] Add NaN value check for participant column in validate_balanced_panel function --- src/hssm/rl/utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/hssm/rl/utils.py b/src/hssm/rl/utils.py index 485e02ea3..bbedaa7a8 100644 --- a/src/hssm/rl/utils.py +++ b/src/hssm/rl/utils.py @@ -39,6 +39,13 @@ def validate_balanced_panel( "`participant_col`." 
) + n_null = data[participant_col].isna().sum() + if n_null > 0: + raise ValueError( + f"Column '{participant_col}' contains {n_null} NaN value(s). " + "All rows must have a valid participant identifier." + ) + counts = data.groupby(participant_col).size() if counts.nunique() != 1: raise ValueError( From d84a800d154e49847cc423d69bb451c5b8b6adfd Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 14:30:10 -0500 Subject: [PATCH 07/32] Add validation for ssm_logp_func in RLSSMConfig to ensure it is callable and has required attributes --- src/hssm/config.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/hssm/config.py b/src/hssm/config.py index 64e33e932..c19b66a1b 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -371,6 +371,27 @@ def validate(self) -> None: "Please provide `ssm_logp_func`: the fully annotated JAX SSM " "log-likelihood function required by `make_rl_logp_op`." ) + if not callable(self.ssm_logp_func): + raise ValueError( + "`ssm_logp_func` must be a callable, " + f"but got {type(self.ssm_logp_func)!r}." + ) + missing_attrs = [ + attr + for attr in ("inputs", "outputs", "computed") + if not hasattr(self.ssm_logp_func, attr) + ] + if missing_attrs: + raise ValueError( + "`ssm_logp_func` must be decorated with `@annotate_function` " + "so that it carries the attributes required by `make_rl_logp_op`. " + f"Missing attribute(s): {missing_attrs}. " + "Decorate the function like:\n\n" + " @annotate_function(\n" + " inputs=[...], outputs=[...], computed={...}\n" + " )\n" + " def my_ssm_logp(lan_matrix): ..." 
+ ) # Validate parameter defaults consistency if self.params_default and self.list_params: From 15ad6e26bbdfae71692674c7ebc9ad24daf1318d Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 14:32:06 -0500 Subject: [PATCH 08/32] Add exclude rules for ruff and mypy hooks to skip tests directory --- .pre-commit-config.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 27c3beb6b..8d45ea931 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,9 +9,12 @@ repos: hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] + exclude: ^tests/ - id: ruff-format + exclude: ^tests/ - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.19.1 # Use the sha / tag you want to point at hooks: - id: mypy args: [--no-strict-optional, --ignore-missing-imports] + exclude: ^tests/ From 262ec07c64a1054dc5efa707b5555b54761a2815 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 14:32:26 -0500 Subject: [PATCH 09/32] Add validation tests for ssm_logp_func in RLSSMConfig to ensure it is callable and properly annotated --- tests/test_rlssm_config.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 9ae8257c5..eecf4105e 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -1,8 +1,8 @@ import pytest import hssm -from hssm.config import Config, RLSSMConfig -from hssm.config import ModelConfig +from hssm.config import Config, ModelConfig, RLSSMConfig +from hssm.utils import annotate_function # Define constants for repeated data structures DEFAULT_RESPONSE = ("rt", "response") @@ -51,6 +51,7 @@ def create_config_dict( # region fixtures and helpers @pytest.fixture def valid_rlssmconfig_kwargs(): + @annotate_function(inputs=["v", "rt", "response"], outputs=["logp"], computed={}) def _dummy_ssm_logp_func(x): return x @@ -206,18 +207,12 @@ def 
test_validate_success(self, valid_rlssmconfig_kwargs): config = RLSSMConfig(**valid_rlssmconfig_kwargs) config.validate() - def test_validate_params_default_mismatch(self): + def test_validate_params_default_mismatch(self, valid_rlssmconfig_kwargs): config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta"], - params_default=[0.5], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ssm_logp_func=lambda x: x, + **{ + **valid_rlssmconfig_kwargs, + "params_default": [0.5], # length 1, but list_params has 2 entries + } ) with pytest.raises( ValueError, @@ -225,6 +220,21 @@ def test_validate_params_default_mismatch(self): ): config.validate() + def test_validate_ssm_logp_func_not_callable(self, valid_rlssmconfig_kwargs): + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + config.ssm_logp_func = "not_a_callable" + with pytest.raises(ValueError, match="must be a callable"): + config.validate() + + def test_validate_ssm_logp_func_missing_annotations(self, valid_rlssmconfig_kwargs): + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + # Replace with a plain callable that lacks @annotate_function attributes + config.ssm_logp_func = lambda x: x + with pytest.raises( + ValueError, match="must be decorated with `@annotate_function`" + ): + config.validate() + class TestRLSSMConfigDefaults: @pytest.mark.parametrize( From 381275a5d0cd7780b175d99e29df0353b54724e2 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 14:32:39 -0500 Subject: [PATCH 10/32] Add tests for NaN participant_id and unannotated ssm_logp_func in RLSSM --- tests/test_rlssm.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index b189fa1ed..1c0cc158e 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -123,6 +123,16 @@ def 
test_rlssm_unbalanced_raises( RLSSM(data=unbalanced, rlssm_config=rlssm_config) +def test_rlssm_nan_participant_id_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """NaN in participant_id column should raise ValueError before groupby silently drops rows.""" + nan_data = rldm_data.copy() + nan_data.loc[nan_data.index[0], "participant_id"] = float("nan") + with pytest.raises(ValueError, match="NaN"): + RLSSM(data=nan_data, rlssm_config=rlssm_config) + + def test_rlssm_missing_ssm_logp_func_raises( rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig ) -> None: @@ -146,6 +156,29 @@ def test_rlssm_missing_ssm_logp_func_raises( RLSSM(data=rldm_data, rlssm_config=bad_config) +def test_rlssm_unannotated_ssm_logp_func_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """A plain callable without @annotate_function attrs should raise ValueError.""" + bad_config = RLSSMConfig( + model_name="rldm_bad", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_loglik_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=list(rlssm_config.extra_fields), + ssm_logp_func=lambda x: x, # callable but no .inputs/.outputs/.computed + ) + with pytest.raises(ValueError, match="annotate_function"): + RLSSM(data=rldm_data, rlssm_config=bad_config) + + # --------------------------------------------------------------------------- # Model-structure tests # --------------------------------------------------------------------------- From 0e9ba42d26ddd82428b0e2b19b6944e874bcc6f5 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 15:13:28 -0500 Subject: [PATCH 11/32] Reject missing data and deadline handling in RLSSM 
initialization to preserve trial sequence integrity --- src/hssm/rl/rlssm.py | 44 +++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 7aea9bef6..d131e8efc 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -30,12 +30,10 @@ from hssm.config import ModelConfig, RLSSMConfig from hssm.defaults import ( INITVAL_JITTER_SETTINGS, - MissingDataNetwork, ) from hssm.distribution_utils import make_distribution from hssm.rl.likelihoods.builder import make_rl_logp_op from hssm.rl.utils import validate_balanced_panel -from hssm.utils import _rearrange_data from ..base import HSSMBase @@ -131,6 +129,27 @@ def __init__( # Validate config (ensures ssm_logp_func is present, etc.) rlssm_config.validate() + # RLSSM reshapes rows into (n_participants, n_trials, ...) by position, + # so _rearrange_data (which moves missing/deadline rows to the front) + # would scramble per-participant trial sequences and corrupt RL dynamics. + # Raise early so the user gets a clear message before model construction. + if missing_data is not False: + raise ValueError( + "RLSSM does not support `missing_data` handling. " + "The RL log-likelihood Op relies on strict row order to recover " + "per-participant trial sequences; rearranging rows for missing RT " + "values would corrupt the RL learning dynamics. " + "Please remove missing trials from the data before passing it to RLSSM." + ) + if deadline is not False: + raise ValueError( + "RLSSM does not support `deadline` handling. " + "The RL log-likelihood Op relies on strict row order to recover " + "per-participant trial sequences; rearranging rows for deadline " + "trials would corrupt the RL learning dynamics. Please remove " + "deadline trials from the data before passing it to RLSSM." + ) + # Infer panel structure and validate balance BEFORE calling super so any # error surfaces before the expensive model-build steps. 
n_participants, n_trials = validate_balanced_panel(data, participant_col) @@ -224,26 +243,9 @@ def _make_model_distribution(self) -> type[pm.Distribution]: - Gradient computation via its embedded VJP. Missing-data network assembly (OPN / CPN) is not yet supported for - RLSSM and logs a warning if requested. + RLSSM and ``missing_data`` / ``deadline`` are rejected in ``__init__`` + before this method is ever reached. """ - # Warn if a missing-data network was requested; not supported yet. - if self.missing_data_network != MissingDataNetwork.NONE: - _logger.warning( - "Missing-data network assembly (OPN/CPN) is not yet supported " - "for RLSSM. The missing_data_network setting will be ignored." - ) - - if self.missing_data: - _logger.info( - "Re-arranging data to separate missing and observed datapoints. " - "Missing data (rt == %s) will be on top, " - "observed datapoints follow.", - self.missing_data_value, - ) - - # Rearrange data so missing rows come first (no-op when missing_data=False). - self.data = _rearrange_data(self.data) - # All RLSSM parameters are treated as trialwise: the Op expects arrays of # length n_total_trials for every parameter, and make_distribution.logp # broadcasts scalar / (1,)-shaped tensors up to (n_obs,) accordingly. 
From 4f28c68c85df6402f0119ef68a3f001a1e085929 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 15:13:37 -0500 Subject: [PATCH 12/32] Add tests to validate error handling for missing data and deadline in RLSSM initialization --- tests/test_rlssm.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index 1c0cc158e..22f91b919 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -179,6 +179,22 @@ def test_rlssm_unannotated_ssm_logp_func_raises( RLSSM(data=rldm_data, rlssm_config=bad_config) +def test_rlssm_missing_data_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """Passing missing_data!=False should raise ValueError with 'missing_data' in msg.""" + with pytest.raises(ValueError, match="missing_data"): + RLSSM(data=rldm_data, rlssm_config=rlssm_config, missing_data=True) + + +def test_rlssm_deadline_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """Passing deadline!=False should raise ValueError with 'deadline' in msg.""" + with pytest.raises(ValueError, match="deadline"): + RLSSM(data=rldm_data, rlssm_config=rlssm_config, deadline=True) + + # --------------------------------------------------------------------------- # Model-structure tests # --------------------------------------------------------------------------- From 5e9f566770d86c9c01644d85baca45caba71e793 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 15:25:21 -0500 Subject: [PATCH 13/32] Refactor path handling for loading RLDM fixture dataset in tests --- tests/test_rlssm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index 22f91b919..ee929625d 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -5,6 +5,8 @@ variant, bambi / PyMC model construction, and a sampling smoke test. 
""" +from pathlib import Path + import jax.numpy as jnp import numpy as np import pandas as pd @@ -53,7 +55,9 @@ def _dummy_ssm_logp(lan_matrix: jnp.ndarray) -> jnp.ndarray: @pytest.fixture(scope="module") def rldm_data() -> pd.DataFrame: """Load the RLDM fixture dataset (balanced panel).""" - raw = np.load("tests/fixtures/rldm_data.npy", allow_pickle=True).item() + raw = np.load( + Path(__file__).parent / "fixtures" / "rldm_data.npy", allow_pickle=True + ).item() return pd.DataFrame(raw["data"]) From 67ac2ce88d143b55ee235c3093a2949a4590c3a2 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 15:45:48 -0500 Subject: [PATCH 14/32] Add fixture to set floatX to float32 for module tests --- tests/test_rlssm.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index ee929625d..c067854f0 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -5,11 +5,13 @@ variant, bambi / PyMC model construction, and a sampling smoke test. 
""" +from collections.abc import Generator from pathlib import Path import jax.numpy as jnp import numpy as np import pandas as pd +import pytensor import pytest import hssm @@ -17,8 +19,6 @@ from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise from hssm.utils import annotate_function -hssm.set_floatX("float32", update_jax=True) - # --------------------------------------------------------------------------- # Module-level annotated helpers (shared by all tests) # --------------------------------------------------------------------------- @@ -52,6 +52,17 @@ def _dummy_ssm_logp(lan_matrix: jnp.ndarray) -> jnp.ndarray: # --------------------------------------------------------------------------- +@pytest.fixture(scope="module", autouse=True) +def _set_floatx_float32() -> Generator[None, None, None]: + """Ensure float32 is used for this module's tests, then restore previous setting.""" + prev_floatx = pytensor.config.floatX + hssm.set_floatX("float32", update_jax=True) + try: + yield + finally: + hssm.set_floatX(prev_floatx, update_jax=True) + + @pytest.fixture(scope="module") def rldm_data() -> pd.DataFrame: """Load the RLDM fixture dataset (balanced panel).""" From e1c05dfeb5782d843dc91265e9b4fa69637aca12 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 15:54:10 -0500 Subject: [PATCH 15/32] Ensure params_is_trialwise aligns with list_params in RLSSM initialization --- src/hssm/rl/rlssm.py | 12 ++++++------ tests/test_rlssm.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index d131e8efc..88e9299c7 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -246,12 +246,12 @@ def _make_model_distribution(self) -> type[pm.Distribution]: RLSSM and ``missing_data`` / ``deadline`` are rejected in ``__init__`` before this method is ever reached. 
""" - # All RLSSM parameters are treated as trialwise: the Op expects arrays of - # length n_total_trials for every parameter, and make_distribution.logp - # broadcasts scalar / (1,)-shaped tensors up to (n_obs,) accordingly. - params_is_trialwise = [ - True for param_name in self.params if param_name != "p_outlier" - ] + # Build params_is_trialwise in the same order as self.list_params so the + # length always matches the list_params= argument passed to make_distribution. + # p_outlier is a scalar mixture weight (not trialwise); every other RLSSM + # parameter is trialwise (the Op receives one value per trial). + assert self.list_params is not None, "list_params should be set by HSSMBase" + params_is_trialwise = [name != "p_outlier" for name in self.list_params] extra_fields_data = ( None diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index c067854f0..b6fd6f389 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -215,6 +215,21 @@ def test_rlssm_deadline_raises( # --------------------------------------------------------------------------- +def test_rlssm_params_is_trialwise_aligned( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """params_is_trialwise must align with list_params (same length, p_outlier=False).""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + assert model.list_params is not None + params_is_trialwise = [name != "p_outlier" for name in model.list_params] + assert len(params_is_trialwise) == len(model.list_params) + for name, is_tw in zip(model.list_params, params_is_trialwise): + if name == "p_outlier": + assert not is_tw, "p_outlier must be non-trialwise" + else: + assert is_tw, f"{name} must be trialwise" + + def test_rlssm_no_lapse(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: """Setting p_outlier=None should remove p_outlier from params.""" model = RLSSM(data=rldm_data, rlssm_config=rlssm_config, p_outlier=None) From 564232b73ef516839ae55648ae0872b95e6d8b68 Mon Sep 17 00:00:00 
2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 16:18:45 -0500 Subject: [PATCH 16/32] Clarify comments on default_priors in ModelConfig and remove unnecessary assertion for list_params --- src/hssm/rl/rlssm.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 88e9299c7..fe3571a8d 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -190,10 +190,10 @@ def __init__( # Build a ModelConfig so HSSMBase._build_model_config can apply the # RLSSM-specific fields (response, list_params, choices, bounds, …). - # default_priors is left as None so that the prior_settings="safe" - # mechanism in HSSMBase assigns sensible priors from bounds. Using - # params_default (scalar floats) here would fix every parameter as a - # constant, which is incorrect. + # default_priors is an empty dict (no parameter-specific priors pre-set) + # so that the prior_settings="safe" mechanism in HSSMBase assigns + # sensible priors from bounds. Populating it with params_default scalar + # floats would fix every parameter as a constant, which is incorrect. mc = ModelConfig( response=(tuple(rlssm_config.response) if rlssm_config.response else None), list_params=list_params, @@ -259,7 +259,6 @@ def _make_model_distribution(self) -> type[pm.Distribution]: else [deepcopy(self.data[field].values) for field in self.extra_fields] ) - assert self.list_params is not None, "list_params should be set" # self.loglik was set to the pytensor Op built in __init__; cast to # narrow the inherited union type so make_distribution's type-checker # accepts it without a runtime penalty. 
From bafc0372c4ff2e3ba9dc5dac93d08b7fb9e0d0a2 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 16:23:58 -0500 Subject: [PATCH 17/32] Update RLSSM to use to_numpy(copy=True) for extra_fields and add test for independent copies --- src/hssm/rl/rlssm.py | 3 +-- tests/test_rlssm.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index fe3571a8d..7d750708e 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -17,7 +17,6 @@ """ import logging -from copy import deepcopy from typing import TYPE_CHECKING, Any, Callable, Literal, cast import bambi as bmb @@ -256,7 +255,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: extra_fields_data = ( None if not self.extra_fields - else [deepcopy(self.data[field].values) for field in self.extra_fields] + else [self.data[field].to_numpy(copy=True) for field in self.extra_fields] ) # self.loglik was set to the pytensor Op built in __init__; cast to diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index b6fd6f389..dcee84e36 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -246,6 +246,39 @@ def test_rlssm_model_built(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) - assert "v" not in model.params +def test_rlssm_extra_fields_are_copies( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """extra_fields passed to make_distribution must be independent numpy copies. + + to_numpy(copy=True) should return a new buffer; if it returned a view, + in-place mutations of the DataFrame would silently corrupt the distribution. 
+ """ + from unittest.mock import patch + + from hssm.distribution_utils import make_distribution as real_make_distribution + + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + captured: dict = {} + + def capturing_make_distribution(*args, **kwargs): + captured["extra_fields"] = kwargs.get("extra_fields") + return real_make_distribution(*args, **kwargs) + + with patch( + "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution + ): + model._make_model_distribution() + + assert captured.get("extra_fields") is not None + for field_name, arr in zip(rlssm_config.extra_fields, captured["extra_fields"]): + original = model.data[field_name].to_numpy() + assert not np.shares_memory(arr, original), ( + f"extra_fields['{field_name}'] shares memory with the DataFrame — " + "it is a view, not a copy" + ) + + def test_rlssm_pymc_model(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: """pymc_model should be accessible after model construction.""" model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) From ba358a497c52690c5bf197a99f0ff1c1ef703b75 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 16:38:27 -0500 Subject: [PATCH 18/32] Refactor parameter name resolution in RLSSM to handle underscores correctly and improve substring checks --- src/hssm/rl/rlssm.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 7d750708e..226a8fab8 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -276,17 +276,17 @@ def _get_prefix(self, name_str: str) -> str: """Resolve parameter prefix, handling underscore-containing RL param names. The base-class implementation splits ``name_str`` on the first ``_`` and - returns that single token as the parameter name. This breaks for RL - parameters whose names contain underscores (e.g. 
``rl_alpha``), because - ``"rl_alpha_Intercept".split("_")[0]`` yields ``"rl"``, which is absent - from ``self.params``. + returns that single token (e.g. ``"rl_alpha_Intercept" → "rl"``), which + breaks for RL parameters whose names contain underscores. It also uses a + substring check (``"p_outlier" in name_str``) for the lapse parameter, + which would misfire for any parameter whose name merely *contains* that + substring. - This override tries successively longer underscore-joined prefixes until - it finds one present in ``self.params``, falling back to the base-class - behaviour when no match is found. + This override replaces both heuristics with a single longest-prefix-first + token search: split on ``_``, then try joining 1…N tokens (longest first) + until a candidate is found in ``self.params``. This is both correct for + multi-token RL param names and collision-free for ``p_outlier``. """ - if "p_outlier" in name_str: - return "p_outlier" if "_" in name_str: parts = name_str.split("_") for i in range(len(parts), 0, -1): From 0bfa755bffd40d3407fa7645036c726535365323 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 16:38:43 -0500 Subject: [PATCH 19/32] Add test for _get_prefix method in RLSSM to ensure token-based matching --- tests/test_rlssm.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index dcee84e36..832a28c4e 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -230,6 +230,19 @@ def test_rlssm_params_is_trialwise_aligned( assert is_tw, f"{name} must be trialwise" +def test_rlssm_get_prefix(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """_get_prefix must use token-based matching, not substring search. 
+ + - 'rl_alpha_Intercept' → 'rl_alpha' (underscore-containing RL param) + - 'p_outlier_log__' → 'p_outlier' (lapse param via token loop, not substring) + - 'a_Intercept' → 'a' (single-token standard param) + """ + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + assert model._get_prefix("rl_alpha_Intercept") == "rl_alpha" + assert model._get_prefix("p_outlier_log__") == "p_outlier" + assert model._get_prefix("a_Intercept") == "a" + + def test_rlssm_no_lapse(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: """Setting p_outlier=None should remove p_outlier from params.""" model = RLSSM(data=rldm_data, rlssm_config=rlssm_config, p_outlier=None) From 5b8a16a6597a4a4a1b3188612d17e59292a4face Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 2 Mar 2026 16:43:33 -0500 Subject: [PATCH 20/32] Refactor RLSSMConfig.from_rlssm_dict to remove model_name parameter and update tests accordingly --- src/hssm/config.py | 5 ++--- tests/test_rlssm_config.py | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index c19b66a1b..bb91224da 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -300,16 +300,15 @@ def n_extra_fields(self) -> int | None: return len(self.extra_fields) if self.extra_fields else None @classmethod - def from_rlssm_dict(cls, model_name: str, config_dict: dict[str, Any]): + def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": """ Create RLSSMConfig from a configuration dictionary. Parameters ---------- - model_name : str - The name of the RLSSM model. config_dict : dict[str, Any] Dictionary containing model configuration. 
Expected keys: + - model_name: Model identifier (required) - description: Model description (optional) - list_params: List of parameter names (required) - extra_fields: List of extra field names from data (required) diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index eecf4105e..96dd8b593 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -154,7 +154,7 @@ def test_from_rlssm_dict_cases( expected_choices, expected_learning_process, ): - config = RLSSMConfig.from_rlssm_dict(model_name, config_dict) + config = RLSSMConfig.from_rlssm_dict(config_dict) assert config.model_name == expected_model_name assert config.params_default == expected_params_default assert config.bounds == expected_bounds @@ -484,7 +484,7 @@ def test_from_rlssm_dict_missing_required(self): with pytest.raises( ValueError, match="decision_process_loglik_kind must be provided" ): - RLSSMConfig.from_rlssm_dict("test_model", config_dict) + RLSSMConfig.from_rlssm_dict(config_dict) def test_missing_decision_process_loglik_kind(self): with pytest.raises(TypeError): @@ -512,7 +512,7 @@ def test_missing_decision_process_loglik_kind(self): with pytest.raises( ValueError, match="decision_process_loglik_kind must be provided" ): - RLSSMConfig.from_rlssm_dict("test_model", config_dict) + RLSSMConfig.from_rlssm_dict(config_dict) def test_with_modelconfig_decision_process(self): decision_config = ModelConfig( From f69f2b6968c2b3a52d92c11c0675b41831c6058d Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 3 Mar 2026 09:42:01 -0500 Subject: [PATCH 21/32] Fix comment in test_rlssm.py to clarify output shape of log-likelihood function --- tests/test_rlssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index 832a28c4e..61937f168 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -34,7 +34,7 @@ # Annotated SSM log-likelihood function (simplified for testing). 
# It receives a 2-D lan_matrix whose columns correspond to # [v, a, z, t, theta, rt, response] -# and returns per-trial log-probabilities of shape (n_total_trials, 1). +# and returns per-trial log-probabilities of shape (n_total_trials,). @annotate_function( inputs=["v", "a", "z", "t", "theta", "rt", "response"], outputs=["logp"], From bad943ddb5a783245331acb97747371475c39687 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 3 Mar 2026 09:44:46 -0500 Subject: [PATCH 22/32] Update RLSSMConfig documentation to mark description as required --- src/hssm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index bb91224da..e8f5b6e74 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -309,7 +309,7 @@ def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": config_dict : dict[str, Any] Dictionary containing model configuration. Expected keys: - model_name: Model identifier (required) - - description: Model description (optional) + - description: Model description (required) - list_params: List of parameter names (required) - extra_fields: List of extra field names from data (required) - params_default: Default parameter values (required) From 241aad297d384188b7812d1f11511a3786ca15d3 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 3 Mar 2026 09:50:36 -0500 Subject: [PATCH 23/32] Add ssm_logp_func to RLSSM_REQUIRED_FIELDS and update RLSSMConfig initialization --- src/hssm/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index e8f5b6e74..e01c1a552 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -38,6 +38,7 @@ "decision_process_loglik_kind", "learning_process_loglik_kind", "extra_fields", + "ssm_logp_func", ) ParamSpec = Union[float, dict[str, Any], Prior, None] @@ -341,7 +342,7 @@ def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": 
params_default=config_dict["params_default"], decision_process=config_dict["decision_process"], learning_process=config_dict["learning_process"], - ssm_logp_func=config_dict.get("ssm_logp_func"), + ssm_logp_func=config_dict["ssm_logp_func"], bounds=config_dict.get("bounds", {}), response=config_dict["response"], choices=config_dict["choices"], From ca3816de39c5b4aac7da562f50c6b77398fff706 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 3 Mar 2026 09:52:37 -0500 Subject: [PATCH 24/32] Add dummy ssm_logp_func to tests and validate its presence in RLSSMConfig --- tests/test_rlssm_config.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 96dd8b593..143ff1176 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -16,6 +16,13 @@ } +# Module-level annotated dummy used wherever from_rlssm_dict needs a valid +# ssm_logp_func but the test is not about ssm_logp_func itself. 
+@annotate_function(inputs=["v", "rt", "response"], outputs=["logp"], computed={}) +def _module_dummy_ssm_logp(x): + return x + + # Helper function to create a config dictionary def create_config_dict( model_name, @@ -29,6 +36,7 @@ def create_config_dict( decision_process="ddm", decision_process_loglik_kind="analytical", learning_process_loglik_kind="blackbox", + ssm_logp_func=_module_dummy_ssm_logp, ): return dict( model_name=model_name, @@ -44,6 +52,7 @@ def create_config_dict( decision_process=decision_process, decision_process_loglik_kind=decision_process_loglik_kind, learning_process_loglik_kind=learning_process_loglik_kind, + ssm_logp_func=ssm_logp_func, data={}, ) @@ -480,12 +489,34 @@ def test_from_rlssm_dict_missing_required(self): "bounds": {}, "data": {}, "extra_fields": [], + "ssm_logp_func": _module_dummy_ssm_logp, } with pytest.raises( ValueError, match="decision_process_loglik_kind must be provided" ): RLSSMConfig.from_rlssm_dict(config_dict) + def test_from_rlssm_dict_missing_ssm_logp_func(self): + # Should raise ValueError at construction time if ssm_logp_func is missing + config_dict = { + "model_name": "test_model", + "name": "test_model", + "list_params": ["alpha"], + "params_default": [0.0], + "decision_process": "ddm", + "learning_process": {}, + "learning_process_loglik_kind": "blackbox", + "decision_process_loglik_kind": "analytical", + "response": ["rt", "response"], + "choices": [0, 1], + "description": "desc", + "bounds": {}, + "data": {}, + "extra_fields": [], + } + with pytest.raises(ValueError, match="ssm_logp_func must be provided"): + RLSSMConfig.from_rlssm_dict(config_dict) + def test_missing_decision_process_loglik_kind(self): with pytest.raises(TypeError): RLSSMConfig( @@ -508,6 +539,7 @@ def test_missing_decision_process_loglik_kind(self): "response": ["rt", "response"], "choices": [0, 1], "extra_fields": [], + "ssm_logp_func": _module_dummy_ssm_logp, } with pytest.raises( ValueError, match="decision_process_loglik_kind must be 
provided" From 827025c0463fa8ff32c538722bd1a6f4710fe465 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 3 Mar 2026 10:06:47 -0500 Subject: [PATCH 25/32] Remove unused logging import from rlssm.py --- src/hssm/rl/rlssm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 226a8fab8..c79c5b592 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -16,7 +16,6 @@ standard ``loglik`` / ``loglik_kind`` wrapping pipeline. """ -import logging from typing import TYPE_CHECKING, Any, Callable, Literal, cast import bambi as bmb @@ -36,8 +35,6 @@ from ..base import HSSMBase -_logger = logging.getLogger("hssm") - class RLSSM(HSSMBase): """Reinforcement Learning Sequential Sampling Model. From 292d6f0422c8aa0a96522aad97c8454ccf90363b Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 3 Mar 2026 10:08:39 -0500 Subject: [PATCH 26/32] Remove redundant exclude rule for ruff-format in pre-commit configuration --- .pre-commit-config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8d45ea931..5782d5b88 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,6 @@ repos: args: [--fix, --exit-non-zero-on-fix] exclude: ^tests/ - id: ruff-format - exclude: ^tests/ - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.19.1 # Use the sha / tag you want to point at hooks: From 3b1aaf49f50b9d6a14da5f90b4a03a73c6444eb0 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 4 Mar 2026 12:05:00 -0500 Subject: [PATCH 27/32] Add to_model_config method to RLSSMConfig for ModelConfig conversion --- src/hssm/config.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/hssm/config.py b/src/hssm/config.py index e01c1a552..5c5fd13f9 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -482,6 +482,27 @@ def to_config(self) -> Config: loglik=self.loglik, ) + def to_model_config(self) -> "ModelConfig": + 
"""Build a :class:`ModelConfig` from this :class:`RLSSMConfig`. + + All fields are sourced from ``self``; the backend is fixed to ``"jax"`` + because RLSSM exclusively uses the JAX backend. + + ``default_priors`` is intentionally left empty so the + ``prior_settings="safe"`` mechanism in :class:`~hssm.base.HSSMBase` + assigns sensible priors from bounds rather than fixing every parameter + to a constant scalar. + """ + return ModelConfig( + response=tuple(self.response), # type: ignore[arg-type] + list_params=list(self.list_params), # type: ignore[arg-type] + choices=tuple(self.choices), # type: ignore[arg-type] + default_priors={}, + bounds=self.bounds, + extra_fields=self.extra_fields, + backend="jax", + ) + @dataclass class ModelConfig: From 0678c45720fd7498a5d7a6ffff5b686f42750423 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 4 Mar 2026 12:05:10 -0500 Subject: [PATCH 28/32] Refactor RLSSM to delegate ModelConfig construction to RLSSMConfig and simplify Op parameter handling --- src/hssm/rl/rlssm.py | 45 +++++++++++--------------------------------- 1 file changed, 11 insertions(+), 34 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index c79c5b592..11bab1bdc 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from pytensor.graph import Op -from hssm.config import ModelConfig, RLSSMConfig +from hssm.config import RLSSMConfig from hssm.defaults import ( INITVAL_JITTER_SETTINGS, ) @@ -156,54 +156,31 @@ def __init__( self._n_participants = n_participants self._n_trials = n_trials - # Determine data / param column names for the Op - data_cols: list[str] = ( - list(rlssm_config.response) if rlssm_config.response else ["rt", "response"] - ) - list_params: list[str] = ( - list(rlssm_config.list_params) if rlssm_config.list_params else [] - ) - extra_fields: list[str] = rlssm_config.extra_fields or [] - # Build the differentiable pytensor Op from the annotated SSM function. 
# This Op supersedes the loglik/loglik_kind workflow: it is passed as # `loglik` to HSSMBase so Config.validate() is satisfied, and # _make_model_distribution() uses it directly without any further wrapping. # - # Pass copies of list_params / extra_fields so the closure inside - # make_rl_logp_func captures its own isolated list objects. HSSMBase will - # later append "p_outlier" to self.list_params (which is the SAME list - # object as `list_params` above), and that mutation must NOT be visible to - # the Op's _validate_args_length check at sampling time. + # Fresh list() copies are passed to make_rl_logp_op so the closure inside + # captures its own isolated list objects. HSSMBase will later append + # "p_outlier" to self.list_params, and that mutation must NOT be visible + # to the Op's _validate_args_length check at sampling time. loglik_op = make_rl_logp_op( ssm_logp_func=rlssm_config.ssm_logp_func, n_participants=n_participants, n_trials=n_trials, - data_cols=list(data_cols), - list_params=list(list_params), - extra_fields=list(extra_fields), + data_cols=list(rlssm_config.response), # type: ignore[arg-type] + list_params=list(rlssm_config.list_params), # type: ignore[arg-type] + extra_fields=list(rlssm_config.extra_fields or []), ) - # Build a ModelConfig so HSSMBase._build_model_config can apply the - # RLSSM-specific fields (response, list_params, choices, bounds, …). - # default_priors is an empty dict (no parameter-specific priors pre-set) - # so that the prior_settings="safe" mechanism in HSSMBase assigns - # sensible priors from bounds. Populating it with params_default scalar - # floats would fix every parameter as a constant, which is incorrect. 
- mc = ModelConfig( - response=(tuple(rlssm_config.response) if rlssm_config.response else None), - list_params=list_params, - choices=(tuple(rlssm_config.choices) if rlssm_config.choices else None), - default_priors={}, - bounds=rlssm_config.bounds or {}, - extra_fields=extra_fields if extra_fields else None, - backend="jax", # RLSSM always uses the JAX backend - ) + # Delegate ModelConfig construction to RLSSMConfig, which already owns + # all the required fields (response, list_params, choices, bounds, …). + mc = rlssm_config.to_model_config() super().__init__( data=data, model=rlssm_config.model_name, - choices=list(rlssm_config.choices) if rlssm_config.choices else None, include=include, model_config=mc, # Pass the Op as loglik so Config.validate() is satisfied. From 26a933675fedbda29eb6920b4c99a75cf80245ed Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 9 Mar 2026 14:06:18 -0400 Subject: [PATCH 29/32] Integrate Config and RLSSMConfig into HSSM and RLSSM classes for improved configuration handling --- src/hssm/base.py | 5 ++--- src/hssm/config.py | 29 ++++++++++++++++++++++++++++- src/hssm/data_validator.py | 4 ++-- src/hssm/hssm.py | 3 ++- src/hssm/rl/rlssm.py | 2 +- 5 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/hssm/base.py b/src/hssm/base.py index 0aa1dec10..1b4e1db59 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -264,8 +264,6 @@ class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin): The jitter value for the initial values. """ - config_class = Config - def __init__( self, data: pd.DataFrame, @@ -529,7 +527,8 @@ def _build_model_config( A complete Config object with choices and other settings applied. 
""" # Start with defaults - config = cls.config_class.from_defaults(model, loglik_kind) + # get_config_class is provided by Config/RLSSMConfig mixin through MRO + config = cls.get_config_class().from_defaults(model, loglik_kind) # type: ignore[attr-defined] # Handle user-provided model_config if model_config is not None: diff --git a/src/hssm/config.py b/src/hssm/config.py index 5c5fd13f9..498c53603 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -54,7 +54,7 @@ class BaseModelConfig(ABC): # Data specification response: list[str] | None = field(default_factory=DEFAULT_SSM_OBSERVED_DATA.copy) - choices: tuple[int, ...] | None = DEFAULT_SSM_CHOICES + choices: list[int] | tuple[int, ...] | None = DEFAULT_SSM_CHOICES # Parameter specification list_params: list[str] | None = None @@ -68,6 +68,16 @@ class BaseModelConfig(ABC): # Additional data requirements extra_fields: list[str] | None = None + @classmethod + @abstractmethod + def get_config_class(cls) -> type["BaseModelConfig"]: + """Return the config class for this model type. + + This enables polymorphic config resolution without circular imports. + Each subclass returns itself as the config class. + """ + ... + @abstractmethod def validate(self) -> None: """Validate configuration. 
Must be implemented by subclasses.""" @@ -92,6 +102,11 @@ def __post_init__(self): if self.loglik_kind is None: raise ValueError("loglik_kind is required for Config") + @classmethod + def get_config_class(cls) -> type["Config"]: + """Return Config as the config class for HSSM models.""" + return Config + @classmethod def from_defaults( cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None @@ -290,6 +305,18 @@ def __post_init__(self): if self.loglik_kind is None: self.loglik_kind = "approx_differentiable" + @classmethod + def get_config_class(cls) -> type["RLSSMConfig"]: + """Return RLSSMConfig as the config class for RLSSM models.""" + return RLSSMConfig + + @classmethod + def from_defaults( + cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None + ) -> Config: + """Return the shared Config defaults (delegated to :class:`Config`).""" + return Config.from_defaults(model_name, loglik_kind) + @property def n_params(self) -> int | None: """Return the number of parameters.""" diff --git a/src/hssm/data_validator.py b/src/hssm/data_validator.py index ae09a43fc..ff48e39aa 100644 --- a/src/hssm/data_validator.py +++ b/src/hssm/data_validator.py @@ -17,7 +17,7 @@ class DataValidatorMixin: This class expects subclasses to define the following attributes: - data: pd.DataFrame - response: list[str] - - choices: list[int] + - choices: list[int] | tuple[int, ...] - n_choices: int - extra_fields: list[str] | None - deadline: bool @@ -30,7 +30,7 @@ def __init__( self, data: pd.DataFrame, response: list[str] | None = ["rt", "response"], - choices: list[int] | None = [0, 1], + choices: list[int] | tuple[int, ...] 
| None = [0, 1], n_choices: int = 2, extra_fields: list[str] | None = None, deadline: bool = False, diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 8f179d942..45224cd20 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -30,6 +30,7 @@ ) from .base import HSSMBase +from .config import Config if TYPE_CHECKING: from os import PathLike @@ -74,7 +75,7 @@ def __get__(self, instance, owner): # noqa: D105 return self.fget(owner) -class HSSM(HSSMBase): +class HSSM(HSSMBase, Config): """The basic Hierarchical Sequential Sampling Model (HSSM) class. Parameters diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 11bab1bdc..339dd870c 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -36,7 +36,7 @@ from ..base import HSSMBase -class RLSSM(HSSMBase): +class RLSSM(HSSMBase, RLSSMConfig): """Reinforcement Learning Sequential Sampling Model. Combines a reinforcement learning (RL) process with a sequential sampling From cd660da092b4ea9d2cd9fed684495259421d4fe4 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 9 Mar 2026 15:12:59 -0400 Subject: [PATCH 30/32] Update choices type from list to tuple for consistency in BaseModelConfig and DataValidatorMixin --- src/hssm/config.py | 2 +- src/hssm/data_validator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index 498c53603..31415df9a 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -54,7 +54,7 @@ class BaseModelConfig(ABC): # Data specification response: list[str] | None = field(default_factory=DEFAULT_SSM_OBSERVED_DATA.copy) - choices: list[int] | tuple[int, ...] | None = DEFAULT_SSM_CHOICES + choices: tuple[int, ...] 
| None = DEFAULT_SSM_CHOICES # Parameter specification list_params: list[str] | None = None diff --git a/src/hssm/data_validator.py b/src/hssm/data_validator.py index ff48e39aa..738e85fe3 100644 --- a/src/hssm/data_validator.py +++ b/src/hssm/data_validator.py @@ -30,7 +30,7 @@ def __init__( self, data: pd.DataFrame, response: list[str] | None = ["rt", "response"], - choices: list[int] | tuple[int, ...] | None = [0, 1], + choices: tuple[int, ...] | None = (0, 1), n_choices: int = 2, extra_fields: list[str] | None = None, deadline: bool = False, From db308c6932f6df8b60a84da9236b5bd7869bce44 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 9 Mar 2026 15:33:00 -0400 Subject: [PATCH 31/32] Update choices type from list to tuple in test_constructor for consistency --- tests/test_data_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_data_validator.py b/tests/test_data_validator.py index ecdc927fa..9efbdc05d 100644 --- a/tests/test_data_validator.py +++ b/tests/test_data_validator.py @@ -66,7 +66,7 @@ def test_constructor(base_data): assert isinstance(dv, DataValidatorMixin) assert dv.data.equals(_base_data()) assert dv.response == ["rt", "response"] - assert dv.choices == [0, 1] + assert dv.choices == (0, 1) assert dv.n_choices == 2 assert dv.extra_fields == ["extra"] assert dv.deadline is True From ef000cdec9bed02157699c8ea7c1a404032bffcb Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 27 Mar 2026 11:26:33 -0400 Subject: [PATCH 32/32] Fix formatting of error messages in TestRLSSMConfigValidation for consistency --- tests/test_rlssm_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 143ff1176..55cbbe244 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -178,9 +178,9 @@ class TestRLSSMConfigValidation: [ ("response", None, "Please provide `response` columns"), ("list_params", None, "Please 
provide `list_params"), - ("choices", None, "Please provide `choices"), - ("decision_process", None, "Please specify a `decision_process"), - ("ssm_logp_func", None, "Please provide `ssm_logp_func"), + ("choices", None, "Please provide `choices`"), + ("decision_process", None, "Please specify a `decision_process`"), + ("ssm_logp_func", None, "Please provide `ssm_logp_func`"), ], ) def test_validate_missing_fields(