diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 27c3beb6b..5782d5b88 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,9 +9,11 @@ repos: hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] + exclude: ^tests/ - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.19.1 # Use the sha / tag you want to point at hooks: - id: mypy args: [--no-strict-optional, --ignore-missing-imports] + exclude: ^tests/ diff --git a/src/hssm/__init__.py b/src/hssm/__init__.py index 60dd71020..72aa14ba3 100644 --- a/src/hssm/__init__.py +++ b/src/hssm/__init__.py @@ -11,7 +11,7 @@ import logging import sys -from .config import ModelConfig +from .config import ModelConfig, RLSSMConfig from .datasets import load_data from .defaults import show_defaults from .hssm import HSSM @@ -19,6 +19,7 @@ from .param import UserParam as Param from .prior import Prior from .register import register_model +from .rl import RLSSM from .simulator import simulate_data from .utils import check_data_for_rl, set_floatX @@ -31,6 +32,8 @@ __all__ = [ "HSSM", + "RLSSM", + "RLSSMConfig", "Link", "load_data", "ModelConfig", diff --git a/src/hssm/base.py b/src/hssm/base.py index 0aa1dec10..1b4e1db59 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -264,8 +264,6 @@ class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin): The jitter value for the initial values. """ - config_class = Config - def __init__( self, data: pd.DataFrame, @@ -529,7 +527,8 @@ def _build_model_config( A complete Config object with choices and other settings applied. """ # Start with defaults - config = cls.config_class.from_defaults(model, loglik_kind) + # get_config_class is provided by Config/RLSSMConfig mixin through MRO + config = cls.get_config_class().from_defaults(model, loglik_kind) # type: ignore[attr-defined] # Handle user-provided model_config if model_config is not None: diff --git a/src/hssm/config.py b/src/hssm/config.py index 1ae0ea2a0..31415df9a 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -38,6 +38,7 @@ "decision_process_loglik_kind", "learning_process_loglik_kind", "extra_fields", + "ssm_logp_func", ) ParamSpec = Union[float, dict[str, Any], Prior, None] @@ -67,6 +68,16 @@ class BaseModelConfig(ABC): # Additional data requirements extra_fields: list[str] | None = None + @classmethod + @abstractmethod + def get_config_class(cls) -> type["BaseModelConfig"]: + """Return the config class for this model type. + + This enables polymorphic config resolution without circular imports. + Each subclass returns itself as the config class. + """ + ... + @abstractmethod def validate(self) -> None: """Validate configuration. Must be implemented by subclasses.""" @@ -91,6 +102,11 @@ def __post_init__(self): if self.loglik_kind is None: raise ValueError("loglik_kind is required for Config") + @classmethod + def get_config_class(cls) -> type["Config"]: + """Return Config as the config class for HSSM models.""" + return Config + @classmethod def from_defaults( cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None @@ -266,6 +282,12 @@ class RLSSMConfig(BaseModelConfig): This configuration class is designed for models that combine reinforcement learning processes with sequential sampling decision models (RLSSM). + + The ``ssm_logp_func`` field holds the fully annotated JAX SSM log-likelihood + function (an :class:`AnnotatedFunction`) that is passed directly to + ``make_rl_logp_op``. It supersedes the ``loglik`` / ``loglik_kind`` workflow + used by :class:`HSSM`: the Op is built from ``ssm_logp_func`` and therefore + no ``loglik`` callable needs to be provided. """ decision_process_loglik_kind: str = field(kw_only=True) @@ -273,12 +295,28 @@ class RLSSMConfig(BaseModelConfig): params_default: list[float] = field(kw_only=True) decision_process: str | ModelConfig = field(kw_only=True) learning_process: dict[str, Any] = field(kw_only=True) + # The fully annotated SSM log-likelihood function used by make_rl_logp_op. + # Type is Any to avoid a hard dependency on the AnnotatedFunction Protocol at + # import time; validated at runtime in validate(). + ssm_logp_func: Any = field(default=None, kw_only=True) def __post_init__(self): """Set default loglik_kind for RLSSM models if not provided.""" if self.loglik_kind is None: self.loglik_kind = "approx_differentiable" + @classmethod + def get_config_class(cls) -> type["RLSSMConfig"]: + """Return RLSSMConfig as the config class for RLSSM models.""" + return RLSSMConfig + + @classmethod + def from_defaults( + cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None + ) -> Config: + """Return the shared Config defaults (delegated to :class:`Config`).""" + return Config.from_defaults(model_name, loglik_kind) + @property def n_params(self) -> int | None: """Return the number of parameters.""" @@ -290,17 +328,16 @@ def n_extra_fields(self) -> int | None: return len(self.extra_fields) if self.extra_fields else None @classmethod - def from_rlssm_dict(cls, model_name: str, config_dict: dict[str, Any]): + def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": """ Create RLSSMConfig from a configuration dictionary. Parameters ---------- - model_name : str - The name of the RLSSM model. config_dict : dict[str, Any] Dictionary containing model configuration. Expected keys: - - description: Model description (optional) + - model_name: Model identifier (required) + - description: Model description (required) - list_params: List of parameter names (required) - extra_fields: List of extra field names from data (required) - params_default: Default parameter values (required) @@ -332,6 +369,7 @@ def from_rlssm_dict(cls, model_name: str, config_dict: dict[str, Any]): params_default=config_dict["params_default"], decision_process=config_dict["decision_process"], learning_process=config_dict["learning_process"], + ssm_logp_func=config_dict["ssm_logp_func"], bounds=config_dict.get("bounds", {}), response=config_dict["response"], choices=config_dict["choices"], @@ -355,6 +393,32 @@ def validate(self) -> None: raise ValueError("Please provide `choices` in the configuration.") if self.decision_process is None: raise ValueError("Please specify a `decision_process`.") + if self.ssm_logp_func is None: + raise ValueError( + "Please provide `ssm_logp_func`: the fully annotated JAX SSM " + "log-likelihood function required by `make_rl_logp_op`." + ) + if not callable(self.ssm_logp_func): + raise ValueError( + "`ssm_logp_func` must be a callable, " + f"but got {type(self.ssm_logp_func)!r}." + ) + missing_attrs = [ + attr + for attr in ("inputs", "outputs", "computed") + if not hasattr(self.ssm_logp_func, attr) + ] + if missing_attrs: + raise ValueError( + "`ssm_logp_func` must be decorated with `@annotate_function` " + "so that it carries the attributes required by `make_rl_logp_op`. " + f"Missing attribute(s): {missing_attrs}. " + "Decorate the function like:\n\n" + " @annotate_function(\n" + " inputs=[...], outputs=[...], computed={...}\n" + " )\n" + " def my_ssm_logp(lan_matrix): ..." + ) # Validate parameter defaults consistency if self.params_default and self.list_params: @@ -445,6 +509,27 @@ def to_config(self) -> Config: loglik=self.loglik, ) + def to_model_config(self) -> "ModelConfig": + """Build a :class:`ModelConfig` from this :class:`RLSSMConfig`. + + All fields are sourced from ``self``; the backend is fixed to ``"jax"`` + because RLSSM exclusively uses the JAX backend. + + ``default_priors`` is intentionally left empty so the + ``prior_settings="safe"`` mechanism in :class:`~hssm.base.HSSMBase` + assigns sensible priors from bounds rather than fixing every parameter + to a constant scalar. + """ + return ModelConfig( + response=tuple(self.response), # type: ignore[arg-type] + list_params=list(self.list_params), # type: ignore[arg-type] + choices=tuple(self.choices), # type: ignore[arg-type] + default_priors={}, + bounds=self.bounds, + extra_fields=self.extra_fields, + backend="jax", + ) + @dataclass class ModelConfig: diff --git a/src/hssm/data_validator.py b/src/hssm/data_validator.py index ae09a43fc..738e85fe3 100644 --- a/src/hssm/data_validator.py +++ b/src/hssm/data_validator.py @@ -17,7 +17,7 @@ class DataValidatorMixin: This class expects subclasses to define the following attributes: - data: pd.DataFrame - response: list[str] - - choices: list[int] + - choices: list[int] | tuple[int, ...] - n_choices: int - extra_fields: list[str] | None - deadline: bool @@ -30,7 +30,7 @@ def __init__( self, data: pd.DataFrame, response: list[str] | None = ["rt", "response"], - choices: list[int] | None = [0, 1], + choices: tuple[int, ...] | None = (0, 1), n_choices: int = 2, extra_fields: list[str] | None = None, deadline: bool = False, diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 8f179d942..45224cd20 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -30,6 +30,7 @@ ) from .base import HSSMBase +from .config import Config if TYPE_CHECKING: from os import PathLike @@ -74,7 +75,7 @@ def __get__(self, instance, owner): # noqa: D105 return self.fget(owner) -class HSSM(HSSMBase): +class HSSM(HSSMBase, Config): """The basic Hierarchical Sequential Sampling Model (HSSM) class. Parameters diff --git a/src/hssm/rl/__init__.py b/src/hssm/rl/__init__.py new file mode 100644 index 000000000..09eb646ae --- /dev/null +++ b/src/hssm/rl/__init__.py @@ -0,0 +1,18 @@ +"""Reinforcement learning extensions for HSSM. + +This sub-package provides: + +- :class:`~hssm.rl.rlssm.RLSSM` — the RL + SSM model class. +- :func:`~hssm.rl.utils.validate_balanced_panel` — panel-balance utility. +- :mod:`hssm.rl.likelihoods` — log-likelihood builders + (:func:`~hssm.rl.likelihoods.builder.make_rl_logp_func`, + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`). +""" + +from .rlssm import RLSSM +from .utils import validate_balanced_panel + +__all__ = [ + "RLSSM", + "validate_balanced_panel", +] diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py new file mode 100644 index 000000000..339dd870c --- /dev/null +++ b/src/hssm/rl/rlssm.py @@ -0,0 +1,270 @@ +"""RLSSM: Reinforcement Learning Sequential Sampling Model. + +This module defines the :class:`RLSSM` class, a subclass of :class:`HSSMBase` +for models that couple a reinforcement learning (RL) learning process with a +sequential sampling decision model (SSM). + +The key difference from :class:`HSSM` is the likelihood: + - ``HSSM`` wraps an analytical / ONNX / blackbox callable via + :func:`~hssm.distribution_utils.make_likelihood_callable`. + - ``RLSSM`` builds a differentiable pytensor ``Op`` directly from an + :class:`~hssm.rl.likelihoods.builder.AnnotatedFunction` via + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`, which internally + handles the RL learning rule and per-participant trial structure. + This Op is then passed straight to + :func:`~hssm.distribution_utils.make_distribution`, bypassing the + standard ``loglik`` / ``loglik_kind`` wrapping pipeline. +""" + +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +import bambi as bmb +import pandas as pd +import pymc as pm + +if TYPE_CHECKING: + from pytensor.graph import Op + +from hssm.config import RLSSMConfig +from hssm.defaults import ( + INITVAL_JITTER_SETTINGS, +) +from hssm.distribution_utils import make_distribution +from hssm.rl.likelihoods.builder import make_rl_logp_op +from hssm.rl.utils import validate_balanced_panel + +from ..base import HSSMBase + + +class RLSSM(HSSMBase, RLSSMConfig): + """Reinforcement Learning Sequential Sampling Model. + + Combines a reinforcement learning (RL) process with a sequential sampling + model (SSM) inside a single differentiable likelihood. The RL component + computes trial-wise intermediate parameters (e.g., drift rates) from the + learning history, which are then fed into the SSM log-likelihood. + + The likelihood is built via + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op` from the annotated + SSM function stored in *rlssm_config.ssm_logp_func*. This produces a + differentiable pytensor ``Op`` that is passed directly to + :func:`~hssm.distribution_utils.make_distribution`, superseding the + ``loglik`` / ``loglik_kind`` dispatching used by :class:`~hssm.hssm.HSSM`. + + Parameters + ---------- + data : pd.DataFrame + Trial-level data. Must contain at least the response columns + specified in *rlssm_config* (typically ``"rt"`` and ``"response"``), + a participant identifier column (default ``"participant_id"``), and + any extra fields listed in *rlssm_config.extra_fields*. + The data **must** form a balanced panel: every participant must have + the same number of trials. + rlssm_config : RLSSMConfig + Full configuration for the RLSSM model. Must have ``ssm_logp_func`` + set to the annotated JAX SSM log-likelihood function. + participant_col : str, optional + Name of the column that uniquely identifies participants. + Used to infer ``n_participants`` and ``n_trials`` from *data*. + Defaults to ``"participant_id"``. + include : list, optional + Parameter specifications forwarded to :class:`~hssm.base.HSSMBase`. + p_outlier : float | dict | bmb.Prior | None, optional + Lapse probability specification. Defaults to ``0.05``. + lapse : dict | bmb.Prior | None, optional + Lapse distribution. Defaults to ``Uniform(0, 20)``. + link_settings : Literal["log_logit"] | None, optional + Link-function preset. Defaults to ``None``. + prior_settings : Literal["safe"] | None, optional + Prior preset. Defaults to ``"safe"``. + extra_namespace : dict | None, optional + Extra variables for formula evaluation. Defaults to ``None``. + missing_data : bool | float, optional + Whether to handle missing RT data coded as ``-999.0``. + Defaults to ``False``. + deadline : bool | str, optional + Whether to handle deadline data. Defaults to ``False``. + loglik_missing_data : Callable | None, optional + Custom likelihood for missing observations. Defaults to ``None``. + process_initvals : bool, optional + Whether to post-process initial values. Defaults to ``True``. + initval_jitter : float, optional + Jitter magnitude for initial values. + Defaults to :data:`~hssm.defaults.INITVAL_JITTER_SETTINGS` epsilon. + **kwargs + Additional keyword arguments forwarded to :class:`bmb.Model`. + + Attributes + ---------- + _rlssm_config : RLSSMConfig + The RLSSM configuration object. + _n_participants : int + Number of participants inferred from *data*. + _n_trials : int + Number of trials per participant inferred from *data*. + """ + + def __init__( + self, + data: pd.DataFrame, + rlssm_config: RLSSMConfig, + participant_col: str = "participant_id", + include: list[dict[str, Any] | Any] | None = None, + p_outlier: float | dict | bmb.Prior | None = 0.05, + lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), + link_settings: Literal["log_logit"] | None = None, + prior_settings: Literal["safe"] | None = "safe", + extra_namespace: dict[str, Any] | None = None, + missing_data: bool | float = False, + deadline: bool | str = False, + loglik_missing_data: Callable | None = None, + process_initvals: bool = True, + initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], + **kwargs: Any, + ) -> None: + # Validate config (ensures ssm_logp_func is present, etc.) + rlssm_config.validate() + + # RLSSM reshapes rows into (n_participants, n_trials, ...) by position, + # so _rearrange_data (which moves missing/deadline rows to the front) + # would scramble per-participant trial sequences and corrupt RL dynamics. + # Raise early so the user gets a clear message before model construction. + if missing_data is not False: + raise ValueError( + "RLSSM does not support `missing_data` handling. " + "The RL log-likelihood Op relies on strict row order to recover " + "per-participant trial sequences; rearranging rows for missing RT " + "values would corrupt the RL learning dynamics. " + "Please remove missing trials from the data before passing it to RLSSM." + ) + if deadline is not False: + raise ValueError( + "RLSSM does not support `deadline` handling. " + "The RL log-likelihood Op relies on strict row order to recover " + "per-participant trial sequences; rearranging rows for deadline " + "trials would corrupt the RL learning dynamics. Please remove " + "deadline trials from the data before passing it to RLSSM." + ) + + # Infer panel structure and validate balance BEFORE calling super so any + # error surfaces before the expensive model-build steps. + n_participants, n_trials = validate_balanced_panel(data, participant_col) + + # Store RL-specific state on self BEFORE super().__init__() so that + # _make_model_distribution() (called from super) can access them. + self._rlssm_config = rlssm_config + self._n_participants = n_participants + self._n_trials = n_trials + + # Build the differentiable pytensor Op from the annotated SSM function. + # This Op supersedes the loglik/loglik_kind workflow: it is passed as + # `loglik` to HSSMBase so Config.validate() is satisfied, and + # _make_model_distribution() uses it directly without any further wrapping. + # + # Fresh list() copies are passed to make_rl_logp_op so the closure inside + # captures its own isolated list objects. HSSMBase will later append + # "p_outlier" to self.list_params, and that mutation must NOT be visible + # to the Op's _validate_args_length check at sampling time. + loglik_op = make_rl_logp_op( + ssm_logp_func=rlssm_config.ssm_logp_func, + n_participants=n_participants, + n_trials=n_trials, + data_cols=list(rlssm_config.response), # type: ignore[arg-type] + list_params=list(rlssm_config.list_params), # type: ignore[arg-type] + extra_fields=list(rlssm_config.extra_fields or []), + ) + + # Delegate ModelConfig construction to RLSSMConfig, which already owns + # all the required fields (response, list_params, choices, bounds, …). + mc = rlssm_config.to_model_config() + + super().__init__( + data=data, + model=rlssm_config.model_name, + include=include, + model_config=mc, + # Pass the Op as loglik so Config.validate() is satisfied. + # loglik_kind="approx_differentiable" reflects that the Op is + # differentiable (gradients flow through its VJP). + loglik=loglik_op, + loglik_kind="approx_differentiable", + p_outlier=p_outlier, + lapse=lapse, + link_settings=link_settings, + prior_settings=prior_settings, + extra_namespace=extra_namespace, + missing_data=missing_data, + deadline=deadline, + loglik_missing_data=loglik_missing_data, + process_initvals=process_initvals, + initval_jitter=initval_jitter, + **kwargs, + ) + + def _make_model_distribution(self) -> type[pm.Distribution]: + """Build a pm.Distribution using the pre-built RL log-likelihood Op. + + Unlike :meth:`HSSM._make_model_distribution`, this method does not go + through :func:`~hssm.distribution_utils.make_likelihood_callable`. + Instead it uses ``self.loglik`` directly — the differentiable pytensor + ``Op`` built in :meth:`__init__` from + ``self._rlssm_config.ssm_logp_func``. + + The Op already handles: + - The RL learning rule (computing trial-wise intermediate parameters). + - The per-participant / per-trial data reshaping. + - Gradient computation via its embedded VJP. + + Missing-data network assembly (OPN / CPN) is not yet supported for + RLSSM and ``missing_data`` / ``deadline`` are rejected in ``__init__`` + before this method is ever reached. + """ + # Build params_is_trialwise in the same order as self.list_params so the + # length always matches the list_params= argument passed to make_distribution. + # p_outlier is a scalar mixture weight (not trialwise); every other RLSSM + # parameter is trialwise (the Op receives one value per trial). + assert self.list_params is not None, "list_params should be set by HSSMBase" + params_is_trialwise = [name != "p_outlier" for name in self.list_params] + + extra_fields_data = ( + None + if not self.extra_fields + else [self.data[field].to_numpy(copy=True) for field in self.extra_fields] + ) + + # self.loglik was set to the pytensor Op built in __init__; cast to + # narrow the inherited union type so make_distribution's type-checker + # accepts it without a runtime penalty. + loglik_op = cast("Callable[..., Any] | Op", self.loglik) + return make_distribution( + rv=self.model_name, + loglik=loglik_op, + list_params=self.list_params, + bounds=self.bounds, + lapse=self.lapse, + extra_fields=extra_fields_data, + params_is_trialwise=params_is_trialwise, + ) + + def _get_prefix(self, name_str: str) -> str: + """Resolve parameter prefix, handling underscore-containing RL param names. + + The base-class implementation splits ``name_str`` on the first ``_`` and + returns that single token (e.g. ``"rl_alpha_Intercept" → "rl"``), which + breaks for RL parameters whose names contain underscores. It also uses a + substring check (``"p_outlier" in name_str``) for the lapse parameter, + which would misfire for any parameter whose name merely *contains* that + substring. + + This override replaces both heuristics with a single longest-prefix-first + token search: split on ``_``, then try joining 1…N tokens (longest first) + until a candidate is found in ``self.params``. This is both correct for + multi-token RL param names and collision-free for ``p_outlier``. + """ + if "_" in name_str: + parts = name_str.split("_") + for i in range(len(parts), 0, -1): + candidate = "_".join(parts[:i]) + if candidate in self.params: + return candidate + return super()._get_prefix(name_str) diff --git a/src/hssm/rl/utils.py b/src/hssm/rl/utils.py new file mode 100644 index 000000000..bbedaa7a8 --- /dev/null +++ b/src/hssm/rl/utils.py @@ -0,0 +1,56 @@ +"""Utility functions for reinforcement learning + SSM models.""" + +import pandas as pd + + +def validate_balanced_panel( + data: pd.DataFrame, + participant_col: str = "participant_id", +) -> tuple[int, int]: + """Validate that data forms a balanced panel and return its shape. + + A balanced panel requires every participant to have exactly the same number + of trials (rows in *data*). + + Parameters + ---------- + data : pd.DataFrame + The DataFrame to validate. + participant_col : str, optional + Name of the column that identifies participants. + Defaults to ``"participant_id"``. + + Returns + ------- + tuple[int, int] + ``(n_participants, n_trials)`` where *n_trials* is the number of rows + per participant. + + Raises + ------ + ValueError + If *participant_col* is not present in *data*, or if the panel is + unbalanced (participants have different trial counts). + """ + if participant_col not in data.columns: + raise ValueError( + f"Column '{participant_col}' not found in data. " + "Please provide the correct participant column name via " + "`participant_col`." + ) + + n_null = data[participant_col].isna().sum() + if n_null > 0: + raise ValueError( + f"Column '{participant_col}' contains {n_null} NaN value(s). " + "All rows must have a valid participant identifier." + ) + + counts = data.groupby(participant_col).size() + if counts.nunique() != 1: + raise ValueError( + "Data must form balanced panels: all participants must have the " + f"same number of trials. Observed trial counts: {dict(counts)}" + ) + + return int(len(counts)), int(counts.iloc[0]) diff --git a/tests/test_data_validator.py b/tests/test_data_validator.py index ecdc927fa..9efbdc05d 100644 --- a/tests/test_data_validator.py +++ b/tests/test_data_validator.py @@ -66,7 +66,7 @@ def test_constructor(base_data): assert isinstance(dv, DataValidatorMixin) assert dv.data.equals(_base_data()) assert dv.response == ["rt", "response"] - assert dv.choices == [0, 1] + assert dv.choices == (0, 1) assert dv.n_choices == 2 assert dv.extra_fields == ["extra"] assert dv.deadline is True diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py new file mode 100644 index 000000000..61937f168 --- /dev/null +++ b/tests/test_rlssm.py @@ -0,0 +1,311 @@ +"""Tests for the RLSSM class. + +Mirrors the structure of tests/test_hssm.py, covering initialisation, +config validation, param keys, balanced-panel enforcement, the no-lapse +variant, bambi / PyMC model construction, and a sampling smoke test. +""" + +from collections.abc import Generator +from pathlib import Path + +import jax.numpy as jnp +import numpy as np +import pandas as pd +import pytensor +import pytest + +import hssm +from hssm import RLSSM, RLSSMConfig +from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise +from hssm.utils import annotate_function + +# --------------------------------------------------------------------------- +# Module-level annotated helpers (shared by all tests) +# --------------------------------------------------------------------------- + +# Annotate the RL learning function: maps +# (rl_alpha, scaler, response, feedback) -> v +_compute_v_annotated = annotate_function( + inputs=["rl_alpha", "scaler", "response", "feedback"], + outputs=["v"], +)(compute_v_subject_wise) + + +# Annotated SSM log-likelihood function (simplified for testing). +# It receives a 2-D lan_matrix whose columns correspond to +# [v, a, z, t, theta, rt, response] +# and returns per-trial log-probabilities of shape (n_total_trials,). +@annotate_function( + inputs=["v", "a", "z", "t", "theta", "rt", "response"], + outputs=["logp"], + computed={"v": _compute_v_annotated}, +) +def _dummy_ssm_logp(lan_matrix: jnp.ndarray) -> jnp.ndarray: + """Return per-trial log-probabilities (column-sum); structural tests only.""" + # Return 1D (N,) — PyTensor declares the Op output as pt.vector(), so + # gradients arrive as (N,). A (N,1) return causes a VJP shape mismatch. + return jnp.sum(lan_matrix, axis=1) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module", autouse=True) +def _set_floatx_float32() -> Generator[None, None, None]: + """Ensure float32 is used for this module's tests, then restore previous setting.""" + prev_floatx = pytensor.config.floatX + hssm.set_floatX("float32", update_jax=True) + try: + yield + finally: + hssm.set_floatX(prev_floatx, update_jax=True) + + +@pytest.fixture(scope="module") +def rldm_data() -> pd.DataFrame: + """Load the RLDM fixture dataset (balanced panel).""" + raw = np.load( + Path(__file__).parent / "fixtures" / "rldm_data.npy", allow_pickle=True + ).item() + return pd.DataFrame(raw["data"]) + + +@pytest.fixture(scope="module") +def rlssm_config() -> RLSSMConfig: + """Minimal but valid RLSSMConfig for the RLDM fixture dataset.""" + return RLSSMConfig( + model_name="rldm_test", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_loglik_kind="blackbox", + list_params=["rl_alpha", "scaler", "a", "theta", "t", "z"], + params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5], + bounds={ + "rl_alpha": (0.0, 1.0), + "scaler": (0.0, 10.0), + "a": (0.1, 3.0), + "theta": (-0.1, 0.1), + "t": (0.001, 1.0), + "z": (0.1, 0.9), + }, + learning_process={"v": _compute_v_annotated}, + response=["rt", "response"], + choices=[0, 1], + extra_fields=["feedback"], + ssm_logp_func=_dummy_ssm_logp, + ) + + +# --------------------------------------------------------------------------- +# Initialisation & config-validation tests +# --------------------------------------------------------------------------- + + +def test_rlssm_init(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """Basic RLSSM initialisation should succeed and return an RLSSM instance.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + assert isinstance(model, RLSSM) + assert model.model_name == "rldm_test" + + +def test_rlssm_panel_attrs(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """_n_participants and _n_trials should match the fixture data structure.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + + n_participants = rldm_data["participant_id"].nunique() + n_trials = len(rldm_data) // n_participants + + assert model._n_participants == n_participants + assert model._n_trials == n_trials + + +def test_rlssm_params_keys(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """model.params should contain exactly list_params + p_outlier.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + expected = set(rlssm_config.list_params) | {"p_outlier"} + assert set(model.params.keys()) == expected + + +def test_rlssm_unbalanced_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """Dropping one row should make the panel unbalanced → ValueError.""" + unbalanced = rldm_data.iloc[:-1].copy() + with pytest.raises(ValueError, match="balanced panels"): + RLSSM(data=unbalanced, rlssm_config=rlssm_config) + + +def test_rlssm_nan_participant_id_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """NaN in participant_id column should raise ValueError before groupby silently drops rows.""" + nan_data = rldm_data.copy() + nan_data.loc[nan_data.index[0], "participant_id"] = float("nan") + with pytest.raises(ValueError, match="NaN"): + RLSSM(data=nan_data, rlssm_config=rlssm_config) + + +def test_rlssm_missing_ssm_logp_func_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """RLSSMConfig without ssm_logp_func should raise ValueError on init.""" + bad_config = RLSSMConfig( + model_name="rldm_bad", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_loglik_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=list(rlssm_config.extra_fields), + # ssm_logp_func intentionally omitted → defaults to None + ) + with pytest.raises(ValueError, match="ssm_logp_func"): + RLSSM(data=rldm_data, rlssm_config=bad_config) + + +def test_rlssm_unannotated_ssm_logp_func_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """A plain callable without @annotate_function attrs should raise ValueError.""" + bad_config = RLSSMConfig( + model_name="rldm_bad", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_loglik_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=list(rlssm_config.extra_fields), + ssm_logp_func=lambda x: x, # callable but no .inputs/.outputs/.computed + ) + with pytest.raises(ValueError, match="annotate_function"): + RLSSM(data=rldm_data, rlssm_config=bad_config) + + +def test_rlssm_missing_data_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """Passing missing_data!=False should raise ValueError with 'missing_data' in msg.""" + with pytest.raises(ValueError, match="missing_data"): + RLSSM(data=rldm_data, rlssm_config=rlssm_config, missing_data=True) + + +def test_rlssm_deadline_raises( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """Passing deadline!=False should raise ValueError with 'deadline' in msg.""" + with pytest.raises(ValueError, match="deadline"): + RLSSM(data=rldm_data, rlssm_config=rlssm_config, deadline=True) + + +# --------------------------------------------------------------------------- +# Model-structure tests +# --------------------------------------------------------------------------- + + +def test_rlssm_params_is_trialwise_aligned( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """params_is_trialwise must align with list_params (same length, p_outlier=False).""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + assert model.list_params is not None + params_is_trialwise = [name != "p_outlier" for name in model.list_params] + assert len(params_is_trialwise) == len(model.list_params) + for name, is_tw in zip(model.list_params, params_is_trialwise): + if name == "p_outlier": + assert not is_tw, "p_outlier must be non-trialwise" + else: + assert is_tw, f"{name} must be trialwise" + + +def test_rlssm_get_prefix(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """_get_prefix must use token-based matching, not substring search. + + - 'rl_alpha_Intercept' → 'rl_alpha' (underscore-containing RL param) + - 'p_outlier_log__' → 'p_outlier' (lapse param via token loop, not substring) + - 'a_Intercept' → 'a' (single-token standard param) + """ + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + assert model._get_prefix("rl_alpha_Intercept") == "rl_alpha" + assert model._get_prefix("p_outlier_log__") == "p_outlier" + assert model._get_prefix("a_Intercept") == "a" + + +def test_rlssm_no_lapse(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """Setting p_outlier=None should remove p_outlier from params.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config, p_outlier=None) + assert "p_outlier" not in model.params + + +def test_rlssm_model_built(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """The bambi model should be built and the computed param 'v' absent from params.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + assert model.model is not None + # rl_alpha is a free (sampled) parameter + assert "rl_alpha" in model.params + # v is computed inside the Op; it must NOT appear as a free parameter + assert "v" not in model.params + + +def test_rlssm_extra_fields_are_copies( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """extra_fields passed to make_distribution must be independent numpy copies. + + to_numpy(copy=True) should return a new buffer; if it returned a view, + in-place mutations of the DataFrame would silently corrupt the distribution. + """ + from unittest.mock import patch + + from hssm.distribution_utils import make_distribution as real_make_distribution + + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + captured: dict = {} + + def capturing_make_distribution(*args, **kwargs): + captured["extra_fields"] = kwargs.get("extra_fields") + return real_make_distribution(*args, **kwargs) + + with patch( + "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution + ): + model._make_model_distribution() + + assert captured.get("extra_fields") is not None + for field_name, arr in zip(rlssm_config.extra_fields, captured["extra_fields"]): + original = model.data[field_name].to_numpy() + assert not np.shares_memory(arr, original), ( + f"extra_fields['{field_name}'] shares memory with the DataFrame — " + "it is a view, not a copy" + ) + + +def test_rlssm_pymc_model(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """pymc_model should be accessible after model construction.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + assert model.pymc_model is not None + + +# --------------------------------------------------------------------------- +# Slow sampling smoke test +# --------------------------------------------------------------------------- + + +@pytest.mark.slow +def test_rlssm_sample_smoke(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: + """Minimal sampling run should return an InferenceData object.""" + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + trace = model.sample(draws=2, tune=2, chains=1, cores=1) + assert trace is not None diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 046ee6e4e..55cbbe244 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -1,8 +1,8 @@ import pytest import hssm -from hssm.config import Config, RLSSMConfig -from hssm.config import ModelConfig +from hssm.config import Config, ModelConfig, RLSSMConfig +from hssm.utils import annotate_function # Define constants for repeated data structures DEFAULT_RESPONSE = ("rt", "response") @@ -16,6 +16,13 @@ } +# Module-level annotated dummy used wherever from_rlssm_dict needs a valid +# ssm_logp_func but the test is not about ssm_logp_func itself. +@annotate_function(inputs=["v", "rt", "response"], outputs=["logp"], computed={}) +def _module_dummy_ssm_logp(x): + return x + + # Helper function to create a config dictionary def create_config_dict( model_name, @@ -29,6 +36,7 @@ def create_config_dict( decision_process="ddm", decision_process_loglik_kind="analytical", learning_process_loglik_kind="blackbox", + ssm_logp_func=_module_dummy_ssm_logp, ): return dict( model_name=model_name, @@ -44,6 +52,7 @@ def create_config_dict( decision_process=decision_process, decision_process_loglik_kind=decision_process_loglik_kind, learning_process_loglik_kind=learning_process_loglik_kind, + ssm_logp_func=ssm_logp_func, data={}, ) @@ -51,6 +60,10 @@ def create_config_dict( # region fixtures and helpers @pytest.fixture def valid_rlssmconfig_kwargs(): + @annotate_function(inputs=["v", "rt", "response"], outputs=["logp"], computed={}) + def _dummy_ssm_logp_func(x): + return x + return dict( model_name="test_model", list_params=["alpha", "beta"], @@ -62,6 +75,7 @@ def valid_rlssmconfig_kwargs(): decision_process_loglik_kind="analytical", learning_process_loglik_kind="blackbox", learning_process={}, + ssm_logp_func=_dummy_ssm_logp_func, ) @@ -149,7 +163,7 @@ def test_from_rlssm_dict_cases( expected_choices, expected_learning_process, ): - config = RLSSMConfig.from_rlssm_dict(model_name, config_dict) + config = RLSSMConfig.from_rlssm_dict(config_dict) assert config.model_name == expected_model_name assert config.params_default == expected_params_default assert config.bounds == expected_bounds @@ -164,8 +178,9 @@ class TestRLSSMConfigValidation: [ ("response", None, "Please provide `response` columns"), ("list_params", None, "Please provide `list_params"), - ("choices", None, "Please provide `choices"), - ("decision_process", None, "Please specify a `decision_process"), + ("choices", None, "Please provide `choices`"), + ("decision_process", None, "Please specify a `decision_process`"), + ("ssm_logp_func", None, "Please provide `ssm_logp_func`"), ], ) def test_validate_missing_fields( @@ -201,17 +216,12 @@ def test_validate_success(self, valid_rlssmconfig_kwargs): config = RLSSMConfig(**valid_rlssmconfig_kwargs) config.validate() - def test_validate_params_default_mismatch(self): + def test_validate_params_default_mismatch(self, valid_rlssmconfig_kwargs): config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta"], - params_default=[0.5], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, + **{ + **valid_rlssmconfig_kwargs, + "params_default": [0.5], # length 1, but list_params has 2 entries + } ) with pytest.raises( ValueError, @@ -219,6 +229,21 @@ def test_validate_params_default_mismatch(self): ): config.validate() + def test_validate_ssm_logp_func_not_callable(self, valid_rlssmconfig_kwargs): + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + config.ssm_logp_func = "not_a_callable" + with pytest.raises(ValueError, match="must be a callable"): + config.validate() + + def test_validate_ssm_logp_func_missing_annotations(self, valid_rlssmconfig_kwargs): + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + # Replace with a plain callable that lacks @annotate_function attributes + config.ssm_logp_func = lambda x: x + with pytest.raises( + ValueError, match="must be decorated with `@annotate_function`" + ): + config.validate() + class TestRLSSMConfigDefaults: @pytest.mark.parametrize( @@ -464,11 +489,33 @@ def test_from_rlssm_dict_missing_required(self): "bounds": {}, "data": {}, "extra_fields": [], + "ssm_logp_func": _module_dummy_ssm_logp, } with pytest.raises( ValueError, match="decision_process_loglik_kind must be provided" ): - RLSSMConfig.from_rlssm_dict("test_model", config_dict) + RLSSMConfig.from_rlssm_dict(config_dict) + + def test_from_rlssm_dict_missing_ssm_logp_func(self): + # Should raise ValueError at construction time if ssm_logp_func is missing + config_dict = { + "model_name": "test_model", + "name": "test_model", + "list_params": ["alpha"], + "params_default": [0.0], + "decision_process": "ddm", + "learning_process": {}, + "learning_process_loglik_kind": "blackbox", + "decision_process_loglik_kind": "analytical", + "response": ["rt", "response"], + "choices": [0, 1], + "description": "desc", + "bounds": {}, + "data": {}, + "extra_fields": [], + } + with pytest.raises(ValueError, match="ssm_logp_func must be provided"): + RLSSMConfig.from_rlssm_dict(config_dict) def test_missing_decision_process_loglik_kind(self): with pytest.raises(TypeError): @@ -492,11 +539,12 @@ def test_missing_decision_process_loglik_kind(self): "response": ["rt", "response"], "choices": [0, 1], "extra_fields": [], + "ssm_logp_func": _module_dummy_ssm_logp, } with pytest.raises( ValueError, match="decision_process_loglik_kind must be provided" ): - RLSSMConfig.from_rlssm_dict("test_model", config_dict) + RLSSMConfig.from_rlssm_dict(config_dict) def test_with_modelconfig_decision_process(self): decision_config = ModelConfig(