From 9e5e8be1483ffad8925f036db554c8213499d34a Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 11:24:12 -0400 Subject: [PATCH 01/44] Add deprecation warnings for model_config attributes in HSSMBase --- src/hssm/base.py | 117 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/src/hssm/base.py b/src/hssm/base.py index 1b4e1db5..2b0e415e 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -457,6 +457,123 @@ def _make_model_distribution(self) -> type[pm.Distribution]: """ ... + def _deprecation_warn(self, name: str) -> None: + """Emit a DeprecationWarning advising to use the typed config. + + Parameters + ---------- + name + Attribute name being deprecated. + """ + warnings.warn( + f"`{name}` is deprecated; use `self.model_config.{name}` instead.", + DeprecationWarning, + stacklevel=2, + ) + + @property + def response(self): + """Deprecated proxy for `self.model_config.response`. + + Returns the list of observed response column names or None. + """ + self._deprecation_warn("response") + return ( + list(self.model_config.response) + if self.model_config.response is not None + else None + ) + + @response.setter + def response(self, value): + """Set the model_config.response value (deprecated). + + Converts the assigned value to a list or None before assignment. + """ + self._deprecation_warn("response") + self.model_config.response = list(value) if value is not None else None + + @property + def list_params(self): + """Deprecated proxy for `self.model_config.list_params`. + + Returns the parameter name list used by the likelihood. + """ + self._deprecation_warn("list_params") + return self.model_config.list_params + + @list_params.setter + def list_params(self, value): + """Set the model_config.list_params value (deprecated).""" + self._deprecation_warn("list_params") + self.model_config.list_params = value + + @property + def choices(self): + """Deprecated proxy for `self.model_config.choices`. + + Returns a tuple of valid response choices. + """ + self._deprecation_warn("choices") + return self.model_config.choices + + @choices.setter + def choices(self, value): + """Set the model_config.choices value (deprecated).""" + self._deprecation_warn("choices") + self.model_config.choices = value + + @property + def model_name(self): + """Deprecated proxy for `self.model_config.model_name`.""" + self._deprecation_warn("model_name") + return self.model_config.model_name + + @model_name.setter + def model_name(self, value): + """Set the model_config.model_name value (deprecated).""" + self._deprecation_warn("model_name") + self.model_config.model_name = value + + @property + def loglik(self): + """Deprecated proxy for `self.model_config.loglik`. + + Returns the configured log-likelihood callable or object. + """ + self._deprecation_warn("loglik") + return self.model_config.loglik + + @loglik.setter + def loglik(self, value): + """Set the model_config.loglik value (deprecated).""" + self._deprecation_warn("loglik") + self.model_config.loglik = value + + @property + def loglik_kind(self): + """Deprecated proxy for `self.model_config.loglik_kind`.""" + self._deprecation_warn("loglik_kind") + return self.model_config.loglik_kind + + @loglik_kind.setter + def loglik_kind(self, value): + """Set the model_config.loglik_kind value (deprecated).""" + self._deprecation_warn("loglik_kind") + self.model_config.loglik_kind = value + + @property + def extra_fields(self): + """Deprecated proxy for `self.model_config.extra_fields`.""" + self._deprecation_warn("extra_fields") + return self.model_config.extra_fields + + @extra_fields.setter + def extra_fields(self, value): + """Set the model_config.extra_fields value (deprecated).""" + self._deprecation_warn("extra_fields") + self.model_config.extra_fields = value + def _fix_scalar_deterministic_dims(self) -> None: """Fix dims metadata for scalar deterministics. From 74acc67b4ba2beec6ed391863e8f49766d1c03f5 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 11:25:00 -0400 Subject: [PATCH 02/44] Refactor HSSMBase to support BaseModelConfig and improve model_config handling --- src/hssm/base.py | 126 ++++++++--------------------------------------- 1 file changed, 20 insertions(+), 106 deletions(-) diff --git a/src/hssm/base.py b/src/hssm/base.py index 2b0e415e..39a9846d 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -8,6 +8,7 @@ import datetime import logging +import warnings from abc import ABC, abstractmethod from copy import deepcopy from inspect import signature @@ -29,7 +30,6 @@ from bambi.model_components import DistributionalComponent from bambi.transformations import transformations_namespace from pymc.model.transform.conditioning import do -from ssms.config import model_config as ssms_model_config from hssm._types import LoglikKind, SupportedModels from hssm.data_validator import DataValidatorMixin @@ -49,7 +49,7 @@ ) from . import plotting -from .config import Config, ModelConfig +from .config import BaseModelConfig, ModelConfig from .param import Params from .param import UserParam as Param @@ -323,26 +323,15 @@ def __init__( self.initval_jitter = initval_jitter # region ===== Construct a model_config from defaults and user inputs ===== - self.model_config: Config = self._build_model_config( - model, loglik_kind, model_config, choices + self.model_config: BaseModelConfig = self._build_model_config( + model, loglik_kind, model_config, choices, loglik ) - self.model_config.update_loglik(loglik) - self.model_config.validate() # endregion - # region ===== Set up shortcuts so old code will work ====== - self.response = ( - list(self.model_config.response) - if self.model_config.response is not None - else None - ) - self.list_params = self.model_config.list_params - self.choices = self.model_config.choices # type: ignore[assignment] - self.model_name = self.model_config.model_name - self.loglik = self.model_config.loglik - self.loglik_kind = self.model_config.loglik_kind - self.extra_fields = self.model_config.extra_fields - # endregion + # Previously scalar shortcuts (e.g. `self.list_params`) were set here. + # These are now provided as deprecated proxy properties that forward to + # the authoritative `self.model_config` object. See property definitions + # below. self._validate_choices() @@ -624,94 +613,19 @@ def _build_model_config( loglik_kind: LoglikKind | None, model_config: ModelConfig | dict | None, choices: list[int] | None, - ) -> Config: - """Build a ModelConfig object from defaults and user inputs. - - Parameters - ---------- - model : SupportedModels | str - The model name. - loglik_kind : LoglikKind | None - The kind of likelihood function. - model_config : ModelConfig | dict | None - User-provided model configuration. - choices : list[int] | None - User-provided choices list. - - Returns - ------- - Config - A complete Config object with choices and other settings applied. + loglik: Any = None, + ) -> BaseModelConfig: + """Delegate config building to the appropriate config-family builder. + + Calls ``_build_model_config`` on the config class returned by + ``get_config_class()``, resolved via the MRO of the calling model + class (e.g. ``HSSM`` → ``Config``, ``RLSSM`` → ``RLSSMConfig``). + The family builder handles defaults resolution, dict normalization, + loglik/choices precedence, and final validation. """ - # Start with defaults - # get_config_class is provided by Config/RLSSMConfig mixin through MRO - config = cls.get_config_class().from_defaults(model, loglik_kind) # type: ignore[attr-defined] - - # Handle user-provided model_config - if model_config is not None: - # Check if choices already exists in the provided config - has_choices = ( - isinstance(model_config, dict) - and "choices" in model_config - or isinstance(model_config, ModelConfig) - and model_config.choices is not None - ) - - # Handle choices conflict or missing choices - if choices is not None: - if has_choices: - _logger.info( - "choices list provided in both model_config and " - "as an argument directly." - " Using the one provided in model_config. \n" - "We recommend providing choices in model_config." - ) - else: - # Add choices to a copy of the config to avoid mutating input - if isinstance(model_config, dict): - model_config = {**model_config, "choices": choices} - else: # ModelConfig instance - # Create a dict from the ModelConfig and add choices - model_config_dict = { - k: getattr(model_config, k) - for k in model_config.__dataclass_fields__ - if getattr(model_config, k) is not None - } - model_config_dict["choices"] = choices - model_config = model_config_dict - - # Convert dict to ModelConfig if needed and update - final_config = ( - model_config - if isinstance(model_config, ModelConfig) - else ModelConfig(**model_config) - ) - config.update_config(final_config) - - # Handle default config (no model_config provided) - else: - # For supported models, defaults already have choices - if model in get_args(SupportedModels): - if choices is not None: - _logger.info( - "Model string is in SupportedModels." - " Ignoring choices arguments." - ) - # For custom models, try to get choices - else: - if choices is not None: - config.update_choices(choices) - elif model in ssms_model_config: - config.update_choices(ssms_model_config[model]["choices"]) - _logger.info( - "choices argument passed as None, " - "but found %s in ssms-simulators. " - "Using choices, from ssm-simulators configs: %s", - model, - ssms_model_config[model]["choices"], - ) - - return config + return cls.get_config_class()._build_model_config( # type: ignore[attr-defined] + model, loglik_kind, model_config, choices, loglik + ) @classproperty def supported_models(cls) -> tuple[SupportedModels, ...]: From 2acc19dadd09d902e24111a9e0e100c764cc0fb3 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 11:27:22 -0400 Subject: [PATCH 03/44] Add model configuration building methods to BaseModelConfig and Config classes --- src/hssm/config.py | 122 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/src/hssm/config.py b/src/hssm/config.py index 31415df9..60e0dcdf 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -20,6 +20,12 @@ if TYPE_CHECKING: from pytensor.tensor.random.op import RandomVariable +import logging + +from ssms.config import model_config as ssms_model_config + +_logger = logging.getLogger("hssm") + # ====== Centralized RLSSM defaults ===== DEFAULT_SSM_OBSERVED_DATA = ["rt", "response"] DEFAULT_RLSSM_OBSERVED_DATA = ["rt", "response"] @@ -88,6 +94,27 @@ def get_defaults(self, param: str) -> Any: """Get default values for a parameter. Must be implemented by subclasses.""" ... + @classmethod + @abstractmethod + def _build_model_config( + cls, + model: "SupportedModels | str", + loglik_kind: "LoglikKind | None", + model_config: "ModelConfig | dict | None", + choices: "list[int] | None", + loglik: Any = None, + ) -> "BaseModelConfig": + """Build and return a fully validated config for this model family. + + Family builders are responsible for: + + - Resolving defaults via ``from_defaults``, + - Normalizing user overrides (dict or typed config) into the family type, + - Applying choices/loglik precedence rules, + - Calling ``validate()`` before returning. + """ + ... + @dataclass class Config(BaseModelConfig): @@ -275,6 +302,80 @@ def get_defaults( """ return self.default_priors.get(param), self.bounds.get(param) + @classmethod + def _build_model_config( + cls, + model: "SupportedModels | str", + loglik_kind: "LoglikKind | None", + model_config: "ModelConfig | dict | None", + choices: "list[int] | None", + loglik: Any = None, + ) -> "Config": + """Build and return a validated Config for standard HSSM models. + + Resolves defaults, normalizes dict/ModelConfig overrides, applies + choices and loglik precedence rules, then validates before returning. + """ + config = cls.from_defaults(model, loglik_kind) + + if model_config is not None: + has_choices = ( + isinstance(model_config, dict) + and "choices" in model_config + or isinstance(model_config, ModelConfig) + and model_config.choices is not None + ) + if choices is not None: + if has_choices: + _logger.info( + "choices list provided in both model_config and " + "as an argument directly." + " Using the one provided in model_config. \n" + "We recommend providing choices in model_config." + ) + else: + if isinstance(model_config, dict): + model_config = {**model_config, "choices": choices} + else: + model_config_dict = { + k: getattr(model_config, k) + for k in model_config.__dataclass_fields__ + if getattr(model_config, k) is not None + } + model_config_dict["choices"] = choices + model_config = model_config_dict + + final_config = ( + model_config + if isinstance(model_config, ModelConfig) + else ModelConfig(**model_config) + ) + config.update_config(final_config) + + else: + if model in get_args(SupportedModels): + if choices is not None: + _logger.info( + "Model string is in SupportedModels." + " Ignoring choices arguments." + ) + else: + if choices is not None: + config.update_choices(choices) + elif model in ssms_model_config: + config.update_choices(ssms_model_config[model]["choices"]) + _logger.info( + "choices argument passed as None, " + "but found %s in ssms-simulators. " + "Using choices, from ssm-simulators configs: %s", + model, + ssms_model_config[model]["choices"], + ) + + config.update_loglik(loglik) + config.validate() + return config + @dataclass class RLSSMConfig(BaseModelConfig): @@ -310,6 +411,27 @@ def get_config_class(cls) -> type["RLSSMConfig"]: """Return RLSSMConfig as the config class for RLSSM models.""" return RLSSMConfig + @classmethod + def _build_model_config( + cls, + model: "SupportedModels | str", + loglik_kind: "LoglikKind | None", + model_config: "ModelConfig | dict | None", + choices: "list[int] | None", + loglik: Any = None, + ) -> "Config": + """Build a validated Config for RLSSM by delegating to Config's builder. + + RLSSM constructs and validates the full RLSSMConfig in its own + ``__init__`` before calling ``HSSMBase.__init__``. The model_config + passed here is a ModelConfig derived from + ``RLSSMConfig.to_model_config()``, so this method delegates directly + to ``Config._build_model_config`` for normalization. + """ + return Config._build_model_config( + model, loglik_kind, model_config, choices, loglik + ) + @classmethod def from_defaults( cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None From 97bf90f68ba823af4c15796aa256c0a43c167d66 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 11:39:50 -0400 Subject: [PATCH 04/44] Refactor model configuration handling in HSSMBase and HSSM classes to delegate config building and improve attribute access --- src/hssm/base.py | 28 +++++----------------------- src/hssm/hssm.py | 2 +- 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/src/hssm/base.py b/src/hssm/base.py index 39a9846d..b55fa2ff 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -323,8 +323,11 @@ def __init__( self.initval_jitter = initval_jitter # region ===== Construct a model_config from defaults and user inputs ===== - self.model_config: BaseModelConfig = self._build_model_config( - model, loglik_kind, model_config, choices, loglik + # Delegate to the config-family builder (Config, RLSSMConfig, etc.) via MRO. + self.model_config: BaseModelConfig = ( + self.get_config_class()._build_model_config( # type: ignore[attr-defined] + model, loglik_kind, model_config, choices, loglik + ) ) # endregion @@ -606,27 +609,6 @@ def _validate_fixed_vectors(self) -> None: f"{len(param.prior)}, but data has {len(self.data)} rows." ) - @classmethod - def _build_model_config( - cls, - model: SupportedModels | str, - loglik_kind: LoglikKind | None, - model_config: ModelConfig | dict | None, - choices: list[int] | None, - loglik: Any = None, - ) -> BaseModelConfig: - """Delegate config building to the appropriate config-family builder. - - Calls ``_build_model_config`` on the config class returned by - ``get_config_class()``, resolved via the MRO of the calling model - class (e.g. ``HSSM`` → ``Config``, ``RLSSM`` → ``RLSSMConfig``). - The family builder handles defaults resolution, dict normalization, - loglik/choices precedence, and final validation. - """ - return cls.get_config_class()._build_model_config( # type: ignore[attr-defined] - model, loglik_kind, model_config, choices, loglik - ) - @classproperty def supported_models(cls) -> tuple[SupportedModels, ...]: """Get a tuple of all supported models. diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 45224cd2..2a5431bf 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -368,7 +368,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: assert self.list_params is not None, "list_params should be set" return make_distribution( - rv=self.model_config.rv or self.model_name, + rv=getattr(self.model_config, "rv", None) or self.model_name, loglik=self.loglik, list_params=self.list_params, bounds=self.bounds, From f75f5e4e262acc878f681c72302e8ed12bf36918 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 12:18:12 -0400 Subject: [PATCH 05/44] Add properties to BaseModelConfig for parameter and extra field counts --- src/hssm/config.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index 60e0dcdf..14e4e163 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -94,6 +94,16 @@ def get_defaults(self, param: str) -> Any: """Get default values for a parameter. Must be implemented by subclasses.""" ... + @property + def n_params(self) -> int | None: + """Return the number of parameters.""" + return len(self.list_params) if self.list_params else None + + @property + def n_extra_fields(self) -> int | None: + """Return the number of extra fields.""" + return len(self.extra_fields) if self.extra_fields else None + @classmethod @abstractmethod def _build_model_config( @@ -439,16 +449,6 @@ def from_defaults( """Return the shared Config defaults (delegated to :class:`Config`).""" return Config.from_defaults(model_name, loglik_kind) - @property - def n_params(self) -> int | None: - """Return the number of parameters.""" - return len(self.list_params) if self.list_params else None - - @property - def n_extra_fields(self) -> int | None: - """Return the number of extra fields.""" - return len(self.extra_fields) if self.extra_fields else None - @classmethod def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": """ From b9b48394587ebb31565e0fff5753704160b8861d Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 12:18:32 -0400 Subject: [PATCH 06/44] Refactor RLSSM attributes to use public naming convention for configuration and participant/trial counts --- src/hssm/rl/rlssm.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 339dd870..0a685f99 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -96,11 +96,11 @@ class RLSSM(HSSMBase, RLSSMConfig): Attributes ---------- - _rlssm_config : RLSSMConfig + config : RLSSMConfig The RLSSM configuration object. - _n_participants : int + n_participants : int Number of participants inferred from *data*. - _n_trials : int + n_trials : int Number of trials per participant inferred from *data*. """ @@ -152,9 +152,9 @@ def __init__( # Store RL-specific state on self BEFORE super().__init__() so that # _make_model_distribution() (called from super) can access them. - self._rlssm_config = rlssm_config - self._n_participants = n_participants - self._n_trials = n_trials + self.config = rlssm_config + self.n_participants = n_participants + self.n_trials = n_trials # Build the differentiable pytensor Op from the annotated SSM function. # This Op supersedes the loglik/loglik_kind workflow: it is passed as @@ -208,7 +208,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: through :func:`~hssm.distribution_utils.make_likelihood_callable`. Instead it uses ``self.loglik`` directly — the differentiable pytensor ``Op`` built in :meth:`__init__` from - ``self._rlssm_config.ssm_logp_func``. + ``self.config.ssm_logp_func``. The Op already handles: - The RL learning rule (computing trial-wise intermediate parameters). From 821392fdbe0d7041cade8af19d59c86aa47185ab Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 12:18:38 -0400 Subject: [PATCH 07/44] Refactor test_rlssm_panel_attrs to use public attributes for participant and trial counts --- tests/test_rlssm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index 61937f16..6cd67e90 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -112,14 +112,14 @@ def test_rlssm_init(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: def test_rlssm_panel_attrs(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: - """_n_participants and _n_trials should match the fixture data structure.""" + """n_participants and n_trials should match the fixture data structure.""" model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) n_participants = rldm_data["participant_id"].nunique() n_trials = len(rldm_data) // n_participants - assert model._n_participants == n_participants - assert model._n_trials == n_trials + assert model.n_participants == n_participants + assert model.n_trials == n_trials def test_rlssm_params_keys(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: From 07169c97b80febb161ec96dffb90e4bbefc75158 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 14:00:27 -0400 Subject: [PATCH 08/44] Refactor HSSMBase to streamline model configuration handling and update initialization parameters --- src/hssm/base.py | 129 ++++++++++++++--------------------------------- 1 file changed, 38 insertions(+), 91 deletions(-) diff --git a/src/hssm/base.py b/src/hssm/base.py index b55fa2ff..d80cb590 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -11,7 +11,6 @@ import warnings from abc import ABC, abstractmethod from copy import deepcopy -from inspect import signature from os import PathLike from pathlib import Path from typing import Any, Callable, Literal, Optional, Union, cast, get_args @@ -31,7 +30,7 @@ from bambi.transformations import transformations_namespace from pymc.model.transform.conditioning import do -from hssm._types import LoglikKind, SupportedModels +from hssm._types import SupportedModels from hssm.data_validator import DataValidatorMixin from hssm.defaults import ( INITVAL_JITTER_SETTINGS, @@ -49,7 +48,7 @@ ) from . import plotting -from .config import BaseModelConfig, ModelConfig +from .config import BaseModelConfig from .param import Params from .param import UserParam as Param @@ -116,65 +115,12 @@ class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin): A list of dictionaries specifying parameter specifications to include in the model. If left unspecified, defaults will be used for all parameter specifications. Defaults to None. - model_config : optional - A dictionary containing the model configuration information. If None is - provided, defaults will be used if there are any. Defaults to None. - Fields for this `dict` are usually: - - - `"list_params"`: a list of parameters indicating the parameters of the model. - The order in which the parameters are specified in this list is important. - Values for each parameter will be passed to the likelihood function in this - order. - - `"backend"`: Only used when `loglik_kind` is `approx_differentiable` and - an onnx file is supplied for the likelihood approximation network (LAN). - Valid values are `"jax"` or `"pytensor"`. It determines whether the LAN in - ONNX should be converted to `"jax"` or `"pytensor"`. If not provided, - `jax` will be used for maximum performance. - - `"default_priors"`: A `dict` indicating the default priors for each parameter. - - `"bounds"`: A `dict` indicating the boundaries for each parameter. In the case - of LAN, these bounds are training boundaries. - - `"rv"`: Optional. Can be a `RandomVariable` class containing the user's own - `rng_fn` function for sampling from the distribution that the user is - supplying. If not supplied, HSSM will automatically generate a - `RandomVariable` using the simulator identified by `model` from the - `ssm_simulators` package. If `model` is not supported in `ssm_simulators`, - a warning will be raised letting the user know that sampling from the - `RandomVariable` will result in errors. - - `"extra_fields"`: Optional. A list of strings indicating the additional - columns in `data` that will be passed to the likelihood function for - calculation. This is helpful if the likelihood function depends on data - other than the observed data and the parameter values. - loglik : optional - A likelihood function. Defaults to None. Requirements are: - - 1. if `loglik_kind` is `"analytical"` or `"blackbox"`, a pm.Distribution, a - pytensor Op, or a Python callable can be used. Signatures are: - - `pm.Distribution`: needs to have parameters specified exactly as listed in - `list_params` - - `pytensor.graph.Op` and `Callable`: needs to accept the parameters - specified exactly as listed in `list_params` - 2. If `loglik_kind` is `"approx_differentiable"`, then in addition to the - specifications above, a `str` or `Pathlike` can also be used to specify a - path to an `onnx` file. If a `str` is provided, HSSM will first look locally - for an `onnx` file. If that is not successful, HSSM will try to download - that `onnx` file from Hugging Face hub. - 3. It can also be `None`, in which case a default likelihood function will be - used - loglik_kind : optional - A string that specifies the kind of log-likelihood function specified with - `loglik`. Defaults to `None`. Can be one of the following: - - - `"analytical"`: an analytical (approximation) likelihood function. It is - differentiable and can be used with samplers that requires differentiation. - - `"approx_differentiable"`: a likelihood approximation network (LAN) likelihood - function. It is differentiable and can be used with samplers that requires - differentiation. - - `"blackbox"`: a black box likelihood function. It is typically NOT - differentiable. - - `None`, in which a default will be used. For `ddm` type of models, the default - will be `analytical`. For other models supported, it will be - `approx_differentiable`. If the model is a custom one, a ValueError - will be raised. + model_config + A fully-initialised :class:`~hssm.config.BaseModelConfig` instance + (e.g. :class:`~hssm.config.Config` or + :class:`~hssm.config.RLSSMConfig`) produced by the subclass before + calling ``super().__init__``. All likelihood, parameter, and data + information is drawn from this object. p_outlier : optional The fixed lapse probability or the prior distribution of the lapse probability. Defaults to a fixed value of 0.05. When `None`, the lapse probability will not @@ -267,14 +213,8 @@ class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin): def __init__( self, data: pd.DataFrame, - model: SupportedModels | str = "ddm", - choices: list[int] | None = None, + model_config: BaseModelConfig, include: list[dict[str, Any] | Param] | None = None, - model_config: ModelConfig | dict | None = None, - loglik: ( - str | PathLike | Callable | pytensor.graph.Op | type[pm.Distribution] | None - ) = None, - loglik_kind: LoglikKind | None = None, p_outlier: float | dict | bmb.Prior | None = 0.05, lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), global_formula: str | None = None, @@ -290,14 +230,6 @@ def __init__( initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], **kwargs, ): - # ===== init args for save/load models ===== - self._init_args = { - k: v for k, v in locals().items() if k not in ["self", "kwargs"] - } - if kwargs: - self._init_args.update(kwargs) - # endregion - # ===== Input Data & Configuration ===== self.data = data.copy() self.global_formula = global_formula @@ -322,13 +254,8 @@ def __init__( self._initvals: dict[str, Any] = {} self.initval_jitter = initval_jitter - # region ===== Construct a model_config from defaults and user inputs ===== - # Delegate to the config-family builder (Config, RLSSMConfig, etc.) via MRO. - self.model_config: BaseModelConfig = ( - self.get_config_class()._build_model_config( # type: ignore[attr-defined] - model, loglik_kind, model_config, choices, loglik - ) - ) + # region ===== Store the pre-built config ===== + self.model_config: BaseModelConfig = model_config # endregion # Previously scalar shortcuts (e.g. `self.list_params`) were set here. @@ -620,13 +547,33 @@ def supported_models(cls) -> tuple[SupportedModels, ...]: """ return get_args(SupportedModels) - @classmethod - def _store_init_args(cls, *args, **kwargs): - """Store initialization arguments using signature binding.""" - sig = signature(cls.__init__) - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - return {k: v for k, v in bound_args.arguments.items() if k != "self"} + @staticmethod + def _store_init_args( + local_vars: dict[str, Any], extra_kwargs: dict[str, Any] + ) -> dict[str, Any]: + """Capture subclass ``__init__`` arguments for save/load serialisation. + + Call this at the very start of a subclass ``__init__`` before any local + variables are assigned, passing ``locals()`` and the ``**kwargs`` dict:: + + self._init_args = self._store_init_args(locals(), kwargs) + + Parameters + ---------- + local_vars + The ``locals()`` snapshot from the subclass ``__init__``. + extra_kwargs + The ``**kwargs`` dict captured by the subclass ``__init__``. + + Returns + ------- + dict[str, Any] + A mapping of parameter names to their values, suitable for + reconstructing the instance via ``cls(**init_args)``. + """ + result = {k: v for k, v in local_vars.items() if k not in ("self", "kwargs")} + result.update(extra_kwargs) + return result def find_MAP(self, **kwargs): """Perform Maximum A Posteriori estimation. From 0395ec22692f2a38135d527ad154e1a976287b57 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 14:00:59 -0400 Subject: [PATCH 09/44] Refactor BaseModelConfig and RLSSMConfig by removing unused abstract methods and adding a new method for building validated Config instances --- src/hssm/config.py | 92 +++++++++++++++------------------------------- 1 file changed, 29 insertions(+), 63 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index 14e4e163..ba4500da 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -74,16 +74,6 @@ class BaseModelConfig(ABC): # Additional data requirements extra_fields: list[str] | None = None - @classmethod - @abstractmethod - def get_config_class(cls) -> type["BaseModelConfig"]: - """Return the config class for this model type. - - This enables polymorphic config resolution without circular imports. - Each subclass returns itself as the config class. - """ - ... - @abstractmethod def validate(self) -> None: """Validate configuration. Must be implemented by subclasses.""" @@ -104,27 +94,6 @@ def n_extra_fields(self) -> int | None: """Return the number of extra fields.""" return len(self.extra_fields) if self.extra_fields else None - @classmethod - @abstractmethod - def _build_model_config( - cls, - model: "SupportedModels | str", - loglik_kind: "LoglikKind | None", - model_config: "ModelConfig | dict | None", - choices: "list[int] | None", - loglik: Any = None, - ) -> "BaseModelConfig": - """Build and return a fully validated config for this model family. - - Family builders are responsible for: - - - Resolving defaults via ``from_defaults``, - - Normalizing user overrides (dict or typed config) into the family type, - - Applying choices/loglik precedence rules, - - Calling ``validate()`` before returning. - """ - ... - @dataclass class Config(BaseModelConfig): @@ -139,11 +108,6 @@ def __post_init__(self): if self.loglik_kind is None: raise ValueError("loglik_kind is required for Config") - @classmethod - def get_config_class(cls) -> type["Config"]: - """Return Config as the config class for HSSM models.""" - return Config - @classmethod def from_defaults( cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None @@ -416,32 +380,6 @@ def __post_init__(self): if self.loglik_kind is None: self.loglik_kind = "approx_differentiable" - @classmethod - def get_config_class(cls) -> type["RLSSMConfig"]: - """Return RLSSMConfig as the config class for RLSSM models.""" - return RLSSMConfig - - @classmethod - def _build_model_config( - cls, - model: "SupportedModels | str", - loglik_kind: "LoglikKind | None", - model_config: "ModelConfig | dict | None", - choices: "list[int] | None", - loglik: Any = None, - ) -> "Config": - """Build a validated Config for RLSSM by delegating to Config's builder. - - RLSSM constructs and validates the full RLSSMConfig in its own - ``__init__`` before calling ``HSSMBase.__init__``. The model_config - passed here is a ModelConfig derived from - ``RLSSMConfig.to_model_config()``, so this method delegates directly - to ``Config._build_model_config`` for normalization. - """ - return Config._build_model_config( - model, loglik_kind, model_config, choices, loglik - ) - @classmethod def from_defaults( cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None @@ -580,7 +518,7 @@ def get_defaults( return default_val, self.bounds.get(param) - def to_config(self) -> Config: + def to_config(self) -> "Config": """Convert to standard Config for compatibility with HSSM. This method transforms the RLSSM configuration into a standard Config @@ -652,6 +590,34 @@ def to_model_config(self) -> "ModelConfig": backend="jax", ) + def _build_model_config(self, loglik_op: Any) -> "Config": + """Build a validated :class:`Config` for use by :class:`~hssm.rl.rlssm.RLSSM`. + + Converts this :class:`RLSSMConfig` to a :class:`ModelConfig`, then + delegates to :meth:`Config._build_model_config` using the pre-built + differentiable Op as ``loglik``. + + Parameters + ---------- + loglik_op + The differentiable pytensor Op produced by + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`. + + Returns + ------- + Config + A fully validated :class:`Config` ready to pass to + :meth:`~hssm.base.HSSMBase.__init__`. + """ + mc = self.to_model_config() + return Config._build_model_config( + self.model_name, + "approx_differentiable", + mc, + None, + loglik_op, + ) + @dataclass class ModelConfig: From 626301d61a22334ae782933b34cfd7e2b8a04f21 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 14:01:15 -0400 Subject: [PATCH 10/44] Refactor HSSM class to remove Config inheritance and add initialization parameters for model configuration --- src/hssm/hssm.py | 61 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 57 insertions(+), 4 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 2a5431bf..af3e5b03 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -9,13 +9,18 @@ import logging from copy import deepcopy from inspect import isclass +from os import PathLike from typing import TYPE_CHECKING, Any, Callable, Literal from typing import cast as typing_cast +import bambi as bmb import numpy as np +import pandas as pd import pymc as pm +from hssm._types import LoglikKind, SupportedModels from hssm.defaults import ( + INITVAL_JITTER_SETTINGS, MissingDataNetwork, missing_data_networks_suffix, ) @@ -30,11 +35,9 @@ ) from .base import HSSMBase -from .config import Config +from .config import Config, ModelConfig if TYPE_CHECKING: - from os import PathLike - from pytensor.graph.op import Op _logger = logging.getLogger("hssm") @@ -75,7 +78,7 @@ def __get__(self, instance, owner): # noqa: D105 return self.fget(owner) -class HSSM(HSSMBase, Config): +class HSSM(HSSMBase): """The basic Hierarchical Sequential Sampling Model (HSSM) class. Parameters @@ -248,6 +251,56 @@ class HSSM(HSSMBase, Config): The jitter value for the initial values. """ + def __init__( + self, + data: pd.DataFrame, + model: SupportedModels | str = "ddm", + choices: list[int] | None = None, + include: list[dict[str, Any] | Any] | None = None, + model_config: ModelConfig | dict | None = None, + loglik: ( + str | PathLike | Callable | pm.Distribution | type[pm.Distribution] | None + ) = None, + loglik_kind: LoglikKind | None = None, + p_outlier: float | dict | bmb.Prior | None = 0.05, + lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), + global_formula: str | None = None, + link_settings: Literal["log_logit"] | None = None, + prior_settings: Literal["safe"] | None = "safe", + extra_namespace: dict[str, Any] | None = None, + missing_data: bool | float = False, + deadline: bool | str = False, + loglik_missing_data: (str | PathLike | Callable | None) = None, + process_initvals: bool = True, + initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], + **kwargs: Any, + ) -> None: + # ===== save/load serialisation ===== + self._init_args = self._store_init_args(locals(), kwargs) + + # Build typed Config via factory + config = Config._build_model_config( + model, loglik_kind, model_config, choices, loglik + ) + + super().__init__( + data=data, + model_config=config, + include=include, + p_outlier=p_outlier, + lapse=lapse, + global_formula=global_formula, + link_settings=link_settings, + prior_settings=prior_settings, + extra_namespace=extra_namespace, + missing_data=missing_data, + deadline=deadline, + loglik_missing_data=loglik_missing_data, + process_initvals=process_initvals, + initval_jitter=initval_jitter, + **kwargs, + ) + def _make_model_distribution(self) -> type[pm.Distribution]: """Make a pm.Distribution for the model.""" ### Logic for different types of likelihoods: From 6c13443f0cf63667a1e24e7432884470ec5651f1 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 14:01:31 -0400 Subject: [PATCH 11/44] Refactor RLSSM class to remove RLSSMConfig inheritance and streamline model configuration handling --- src/hssm/rl/rlssm.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 0a685f99..8f22b40a 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -36,7 +36,7 @@ from ..base import HSSMBase -class RLSSM(HSSMBase, RLSSMConfig): +class RLSSM(HSSMBase): """Reinforcement Learning Sequential Sampling Model. Combines a reinforcement learning (RL) process with a sequential sampling @@ -122,6 +122,9 @@ def __init__( initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], **kwargs: Any, ) -> None: + # ===== save/load serialisation ===== + self._init_args = self._store_init_args(locals(), kwargs) + # Validate config (ensures ssm_logp_func is present, etc.) rlssm_config.validate() @@ -174,20 +177,15 @@ def __init__( extra_fields=list(rlssm_config.extra_fields or []), ) - # Delegate ModelConfig construction to RLSSMConfig, which already owns - # all the required fields (response, list_params, choices, bounds, …). - mc = rlssm_config.to_model_config() + # Build a typed Config instance via RLSSMConfig's own factory method. + # The differentiable Op is passed so Config.validate() is satisfied; + # loglik_kind="approx_differentiable" reflects that the Op has gradients. + config = rlssm_config._build_model_config(loglik_op) super().__init__( data=data, - model=rlssm_config.model_name, + model_config=config, include=include, - model_config=mc, - # Pass the Op as loglik so Config.validate() is satisfied. - # loglik_kind="approx_differentiable" reflects that the Op is - # differentiable (gradients flow through its VJP). - loglik=loglik_op, - loglik_kind="approx_differentiable", p_outlier=p_outlier, lapse=lapse, link_settings=link_settings, From 66296f0872fa0b5ed2fa7a40ca2691ccd72182dd Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 14:33:13 -0400 Subject: [PATCH 12/44] Refactor Config and RLSSMConfig classes to use concrete types in method signatures --- src/hssm/config.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index ba4500da..d5510303 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -279,12 +279,12 @@ def get_defaults( @classmethod def _build_model_config( cls, - model: "SupportedModels | str", - loglik_kind: "LoglikKind | None", - model_config: "ModelConfig | dict | None", - choices: "list[int] | None", + model: SupportedModels | str, + loglik_kind: LoglikKind | None, + model_config: ModelConfig | dict | None, + choices: list[int] | None, loglik: Any = None, - ) -> "Config": + ) -> Config: """Build and return a validated Config for standard HSSM models. Resolves defaults, normalizes dict/ModelConfig overrides, applies @@ -388,7 +388,7 @@ def from_defaults( return Config.from_defaults(model_name, loglik_kind) @classmethod - def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": + def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> RLSSMConfig: """ Create RLSSMConfig from a configuration dictionary. @@ -518,7 +518,7 @@ def get_defaults( return default_val, self.bounds.get(param) - def to_config(self) -> "Config": + def to_config(self) -> Config: """Convert to standard Config for compatibility with HSSM. This method transforms the RLSSM configuration into a standard Config @@ -569,7 +569,7 @@ def to_config(self) -> "Config": loglik=self.loglik, ) - def to_model_config(self) -> "ModelConfig": + def to_model_config(self) -> ModelConfig: """Build a :class:`ModelConfig` from this :class:`RLSSMConfig`. All fields are sourced from ``self``; the backend is fixed to ``"jax"`` @@ -590,7 +590,7 @@ def to_model_config(self) -> "ModelConfig": backend="jax", ) - def _build_model_config(self, loglik_op: Any) -> "Config": + def _build_model_config(self, loglik_op: Any) -> Config: """Build a validated :class:`Config` for use by :class:`~hssm.rl.rlssm.RLSSM`. Converts this :class:`RLSSMConfig` to a :class:`ModelConfig`, then From 801f235c755c55a36145a63dbbcb86d3948f4f09 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 14:40:36 -0400 Subject: [PATCH 13/44] Update Config class parameter types for choices to improve type safety --- src/hssm/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index d5510303..e0aacf2d 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -216,7 +216,7 @@ def update_choices(self, choices: tuple[int, ...] | None) -> None: Parameters ---------- - choices : tuple[int, ...] + choices : tuple[int, ...] | None A tuple of choices. """ if choices is None: @@ -282,7 +282,7 @@ def _build_model_config( model: SupportedModels | str, loglik_kind: LoglikKind | None, model_config: ModelConfig | dict | None, - choices: list[int] | None, + choices: list[int] | tuple[int, ...] | None, loglik: Any = None, ) -> Config: """Build and return a validated Config for standard HSSM models. From 607874cc65f24f6eda34bf856fe9813be2409a44 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 14:42:47 -0400 Subject: [PATCH 14/44] Update choices method to accept a tuple for model_config.choices --- src/hssm/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/base.py b/src/hssm/base.py index d80cb590..af5afd57 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -440,7 +440,7 @@ def choices(self): def choices(self, value): """Set the model_config.choices value (deprecated).""" self._deprecation_warn("choices") - self.model_config.choices = value + self.model_config.choices = tuple(value) if value is not None else None @property def model_name(self): From 9e25a32667cbc1a8f57ce76176956db8dfcb8194 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 15:34:14 -0400 Subject: [PATCH 15/44] Add tests for model configuration handling and choices logic in Config --- tests/test_config.py | 51 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index ba8429d0..9d7f6f26 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,8 +1,11 @@ -import numpy as np +import logging + import pytest +import numpy as np -import hssm from hssm.config import Config, ModelConfig +import hssm + hssm.set_floatX("float32") @@ -79,3 +82,47 @@ def test_update_config(): assert v_prior.name == "Normal" assert v_bounds == (-np.inf, np.inf) + + +class TestConfigBuildModelConfigExtraLogic: + def test_build_model_config_dict_with_choices_conflict(self, caplog): + # model 'ddm' has defaults in hssm.defaults; use a minimal dict override + model_config = { + "response": ("rt", "response"), + "list_params": ["v", "a"], + "choices": (0, 1), + } + # provide a different choices argument — should log that model_config wins + with caplog.at_level(logging.INFO): + cfg = Config._build_model_config("ddm", None, model_config, choices=[1, 0]) + + assert isinstance(cfg, Config) + assert "choices list provided in both model_config" in caplog.text + + def test_build_model_config_modelconfig_adds_choices(self): + # Create a ModelConfig without choices and pass choices argument + mc = ModelConfig(response=("rt", "response"), list_params=["v"], choices=None) + cfg = Config._build_model_config("ddm", None, mc, choices=[0, 1]) + # choices should be applied to resulting Config + assert tuple(cfg.choices) == (0, 1) + + def test_build_model_config_uses_ssms_model_config(self, monkeypatch): + # Simulate an external ssms_model_config entry for a model not in SupportedModels + fake_model = "external_ssm" + fake_choices = [2, 3] + + # Monkeypatch the ssms_model_config mapping in the module + import hssm.config as cfgmod + + monkeypatch.setitem( + cfgmod.ssms_model_config, fake_model, {"choices": fake_choices} + ) + + # Build config with model not in SupportedModels (string) and no choices arg + # provide a loglik_kind so from_defaults does not raise for custom model + # Monkeypatch Config.validate to skip strict checks for this synthetic case + monkeypatch.setattr(Config, "validate", lambda self: None) + result = Config._build_model_config( + fake_model, "analytical", None, choices=None + ) + assert tuple(result.choices) == tuple(fake_choices) From 2b2d66b67c37c4114db351389cbe92ea955ac3f8 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 16:18:57 -0400 Subject: [PATCH 16/44] Enhance HSSMBase initialization with safe default for constructor arguments and explicit error handling for missing snapshot --- src/hssm/base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/hssm/base.py b/src/hssm/base.py index af5afd57..6f66342c 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -237,6 +237,12 @@ def __init__( self.prior_settings = prior_settings self.missing_data_value = -999.0 + # Store a safe default for the constructor-arguments snapshot so that + # pickling / save-load cannot raise AttributeError if a subclass forgets + # to call `_store_init_args(locals(), kwargs)` early. Subclasses are + # still expected to overwrite this with the real snapshot. + self._init_args: dict[str, Any] = {} + # Set up additional namespace for formula evaluation additional_namespace = transformations_namespace.copy() if extra_namespace is not None: @@ -1764,6 +1770,15 @@ def __getstate__(self): A dictionary containing the constructor arguments under the key 'constructor_args'. """ + # Provide a clear error when the initialization snapshot is missing or + # empty. This makes the contract explicit and avoids an AttributeError + # that is easy to miss for subclasses that forget to capture init args. + if not hasattr(self, "_init_args") or not self._init_args: + raise RuntimeError( + "Model state missing initialization snapshot; ensure subclasses " + "call _store_init_args(locals(), kwargs) early in __init__" + ) + state = {"constructor_args": self._init_args} return state From d5b9d80c820a2d6ea581b4fd7d0440f2c0b848ef Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 16:33:19 -0400 Subject: [PATCH 17/44] Update model_config validation to check for non-null choices --- src/hssm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index e0aacf2d..42198cdf 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -295,7 +295,7 @@ def _build_model_config( if model_config is not None: has_choices = ( isinstance(model_config, dict) - and "choices" in model_config + and model_config.get("choices") is not None # choicesn not none in dict or isinstance(model_config, ModelConfig) and model_config.choices is not None ) From 9bf18ea8472bd2a2adb1d855f9d055894bf0d2d7 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 11 Mar 2026 16:45:07 -0400 Subject: [PATCH 18/44] Refactor HSSM distribution method to use typed model_config attributes and avoid deprecated proxy properties --- src/hssm/hssm.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index af3e5b03..f22f7c78 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -419,11 +419,16 @@ def _make_model_distribution(self) -> type[pm.Distribution]: if isinstance(param.prior, np.ndarray) } - assert self.list_params is not None, "list_params should be set" + # Use the typed `model_config` attributes directly to avoid triggering + # DeprecationWarnings from the deprecated proxy properties. + _list_params = self.model_config.list_params + assert _list_params is not None, "list_params should be set" # for type checker + rv_name = getattr(self.model_config, "rv") or self.model_config.model_name + return make_distribution( - rv=getattr(self.model_config, "rv", None) or self.model_name, + rv=rv_name, loglik=self.loglik, - list_params=self.list_params, + list_params=_list_params, bounds=self.bounds, lapse=self.lapse, extra_fields=( From 9af3e95783beec274487e229cc529000f5d7ec37 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 12:14:03 -0400 Subject: [PATCH 19/44] Update test cases to use tuples for choices in model configuration --- tests/test_config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 9d7f6f26..7cf32527 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -102,14 +102,14 @@ def test_build_model_config_dict_with_choices_conflict(self, caplog): def test_build_model_config_modelconfig_adds_choices(self): # Create a ModelConfig without choices and pass choices argument mc = ModelConfig(response=("rt", "response"), list_params=["v"], choices=None) - cfg = Config._build_model_config("ddm", None, mc, choices=[0, 1]) + cfg = Config._build_model_config("ddm", None, mc, choices=(0, 1)) # choices should be applied to resulting Config - assert tuple(cfg.choices) == (0, 1) + assert cfg.choices == (0, 1) def test_build_model_config_uses_ssms_model_config(self, monkeypatch): # Simulate an external ssms_model_config entry for a model not in SupportedModels fake_model = "external_ssm" - fake_choices = [2, 3] + fake_choices = (2, 3) # Monkeypatch the ssms_model_config mapping in the module import hssm.config as cfgmod @@ -125,4 +125,4 @@ def test_build_model_config_uses_ssms_model_config(self, monkeypatch): result = Config._build_model_config( fake_model, "analytical", None, choices=None ) - assert tuple(result.choices) == tuple(fake_choices) + assert result.choices == fake_choices From c2e09d96c1f0349345922b6098f4453e402accba Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 13:21:26 -0400 Subject: [PATCH 20/44] Refactor RLSSM to utilize model_config for list_params and loglik, enhancing type safety and validation --- src/hssm/rl/rlssm.py | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 8f22b40a..e268a16f 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -25,7 +25,8 @@ if TYPE_CHECKING: from pytensor.graph import Op -from hssm.config import RLSSMConfig + +from hssm.config import Config, RLSSMConfig from hssm.defaults import ( INITVAL_JITTER_SETTINGS, ) @@ -217,27 +218,37 @@ def _make_model_distribution(self) -> type[pm.Distribution]: RLSSM and ``missing_data`` / ``deadline`` are rejected in ``__init__`` before this method is ever reached. """ - # Build params_is_trialwise in the same order as self.list_params so the - # length always matches the list_params= argument passed to make_distribution. - # p_outlier is a scalar mixture weight (not trialwise); every other RLSSM - # parameter is trialwise (the Op receives one value per trial). - assert self.list_params is not None, "list_params should be set by HSSMBase" - params_is_trialwise = [name != "p_outlier" for name in self.list_params] + list_params = self.model_config.list_params + assert list_params is not None, "model_config.list_params must be set" + assert isinstance(list_params, list), ( + "model_config.list_params must be a list" + ) # for type checker + + # p_outlier is a scalar mixture weight (not trialwise); every other + # RLSSM parameter is trialwise (the Op receives one value per trial). + params_is_trialwise = [name != "p_outlier" for name in list_params] + extra_fields = self.model_config.extra_fields or [] extra_fields_data = ( None - if not self.extra_fields - else [self.data[field].to_numpy(copy=True) for field in self.extra_fields] + if not extra_fields + else [self.data[field].to_numpy(copy=True) for field in extra_fields] ) - # self.loglik was set to the pytensor Op built in __init__; cast to - # narrow the inherited union type so make_distribution's type-checker - # accepts it without a runtime penalty. - loglik_op = cast("Callable[..., Any] | Op", self.loglik) + # The differentiable pytensor Op was stored on the validated model_config + # during __init__ as its `loglik`; ensure it's present and cast for typing. + assert self.model_config.loglik is not None, "model_config.loglik must be set" + loglik_op = cast("Callable[..., Any] | Op", self.model_config.loglik) + + # `model_config` is typed as BaseModelConfig on the base class; cast + # to `Config` here so static checkers understand `rv` exists. + cfg = cast("Config", self.model_config) + rv_name = cfg.rv or cfg.model_name + return make_distribution( - rv=self.model_name, + rv=rv_name, loglik=loglik_op, - list_params=self.list_params, + list_params=list_params, bounds=self.bounds, lapse=self.lapse, extra_fields=extra_fields_data, From 1aa19f28562b79a7a8acaaf2bf19d0f4f043f89d Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 13:21:44 -0400 Subject: [PATCH 21/44] Fix typo in comment regarding model_config choices validation --- src/hssm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index 42198cdf..8e44b59c 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -295,7 +295,7 @@ def _build_model_config( if model_config is not None: has_choices = ( isinstance(model_config, dict) - and model_config.get("choices") is not None # choicesn not none in dict + and model_config.get("choices") is not None # choices not none in dict or isinstance(model_config, ModelConfig) and model_config.choices is not None ) From 0ea0998217da43604bea91e58c176261e13d57ca Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 13:21:58 -0400 Subject: [PATCH 22/44] Refactor RLSSM tests to access model configuration attributes directly, ensuring consistency with updated model_config structure --- tests/test_rlssm.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index 6cd67e90..f5ef6b03 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -108,7 +108,7 @@ def test_rlssm_init(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: """Basic RLSSM initialisation should succeed and return an RLSSM instance.""" model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) assert isinstance(model, RLSSM) - assert model.model_name == "rldm_test" + assert model.model_config.model_name == "rldm_test" def test_rlssm_panel_attrs(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: @@ -220,10 +220,12 @@ def test_rlssm_params_is_trialwise_aligned( ) -> None: """params_is_trialwise must align with list_params (same length, p_outlier=False).""" model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) - assert model.list_params is not None - params_is_trialwise = [name != "p_outlier" for name in model.list_params] - assert len(params_is_trialwise) == len(model.list_params) - for name, is_tw in zip(model.list_params, params_is_trialwise): + assert model.model_config.list_params is not None + params_is_trialwise = [ + name != "p_outlier" for name in model.model_config.list_params + ] + assert len(params_is_trialwise) == len(model.model_config.list_params) + for name, is_tw in zip(model.model_config.list_params, params_is_trialwise): if name == "p_outlier": assert not is_tw, "p_outlier must be non-trialwise" else: From 8f526f4935cc12d74ccd616a36339c13e681748e Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 13:22:04 -0400 Subject: [PATCH 23/44] Update attribute comparison in compare_hssm_class_attributes to use model_config for model_name --- tests/test_save_load.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 489dca12..614427f7 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -13,7 +13,9 @@ def compare_hssm_class_attributes(model_a, model_b): b = np.array([type(v) for k, v in model_b._init_args.items()]) assert (a == b).all(), "Init arg types not the same" assert (model_a.data).equals(model_b.data), "Data not the same" - assert model_a.model_name == model_b.model_name, "Model name not the same" + assert model_a.model_config.model_name == model_b.model_config.model_name, ( + "Model name not the same" + ) assert model_a.pymc_model._repr_latex_() == model_b.pymc_model._repr_latex_(), ( "Latex representation of model not the same" ) From 5dd68a5b6abd2a584f64d31226966e6e5ab4bf81 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 13:22:10 -0400 Subject: [PATCH 24/44] Update test assertions to access model configuration attributes directly --- tests/test_hssm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_hssm.py b/tests/test_hssm.py index 8309bf01..ffc58b38 100644 --- a/tests/test_hssm.py +++ b/tests/test_hssm.py @@ -123,9 +123,9 @@ def test_custom_model(data_ddm): loglik_kind="analytical", ) - assert model.model_name == "custom" - assert model.loglik_kind == "analytical" - assert model.list_params == ["v", "a", "z", "t", "p_outlier"] + assert model.model_config.model_name == "custom" + assert model.model_config.loglik_kind == "analytical" + assert model.model_config.list_params == ["v", "a", "z", "t", "p_outlier"] @pytest.mark.slow From 7054ccdff46b0b2e6f4e59dcc1fa28fef259ef87 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 13:53:51 -0400 Subject: [PATCH 25/44] Refactor model configuration normalization to streamline choices handling and improve logging --- src/hssm/config.py | 87 +++++++++++++++++++++++++++++----------------- 1 file changed, 56 insertions(+), 31 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index 8e44b59c..91e0609c 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -26,6 +26,7 @@ _logger = logging.getLogger("hssm") + # ====== Centralized RLSSM defaults ===== DEFAULT_SSM_OBSERVED_DATA = ["rt", "response"] DEFAULT_RLSSM_OBSERVED_DATA = ["rt", "response"] @@ -293,37 +294,7 @@ def _build_model_config( config = cls.from_defaults(model, loglik_kind) if model_config is not None: - has_choices = ( - isinstance(model_config, dict) - and model_config.get("choices") is not None # choices not none in dict - or isinstance(model_config, ModelConfig) - and model_config.choices is not None - ) - if choices is not None: - if has_choices: - _logger.info( - "choices list provided in both model_config and " - "as an argument directly." - " Using the one provided in model_config. \n" - "We recommend providing choices in model_config." - ) - else: - if isinstance(model_config, dict): - model_config = {**model_config, "choices": choices} - else: - model_config_dict = { - k: getattr(model_config, k) - for k in model_config.__dataclass_fields__ - if getattr(model_config, k) is not None - } - model_config_dict["choices"] = choices - model_config = model_config_dict - - final_config = ( - model_config - if isinstance(model_config, ModelConfig) - else ModelConfig(**model_config) - ) + final_config = _normalize_model_config_with_choices(model_config, choices) config.update_config(final_config) else: @@ -631,3 +602,57 @@ class ModelConfig: backend: Literal["jax", "pytensor"] | None = None rv: RandomVariable | None = None extra_fields: list[str] | None = None + + +def _normalize_model_config_with_choices( + model_config: "ModelConfig" | dict[str, Any], + choices: list[int] | tuple[int, ...] | None, +) -> "ModelConfig": + """Normalize a user-supplied model_config and apply choices. + + Returns a fresh :class:`ModelConfig` instance and does not mutate the + caller's objects. If both ``model_config`` and ``choices`` are provided + and ``model_config`` already contains ``choices``, the value from + ``model_config`` wins (and a log entry is emitted). + """ + has_choices = ( + isinstance(model_config, dict) + and model_config.get("choices") is not None + or isinstance(model_config, ModelConfig) + and model_config.choices is not None + ) + + # Guard: no explicit choices -> return ModelConfig from input + if choices is None: + return ( + model_config + if isinstance(model_config, ModelConfig) + else ModelConfig(**model_config) + ) + + # Guard: model_config already carries choices -> prefer it + if has_choices: + _logger.info( + "choices list provided in both model_config and " + "as an argument directly. Using the one provided in " + "model_config.\nWe recommend providing choices in model_config." + ) + return ( + model_config + if isinstance(model_config, ModelConfig) + else ModelConfig(**model_config) + ) + + # Apply provided choices without mutating caller's object + if isinstance(model_config, dict): + mc = model_config.copy() + mc["choices"] = choices + return ModelConfig(**mc) + + mc = { + k: getattr(model_config, k) + for k in model_config.__dataclass_fields__ + if getattr(model_config, k) is not None + } + mc["choices"] = choices + return ModelConfig(**mc) From 5e816bc36cba537b7caf3faf267f24166113cd4c Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 13:56:38 -0400 Subject: [PATCH 26/44] Refactor choices handling in Config class to improve clarity and logging --- src/hssm/config.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index 91e0609c..2b1b05c6 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -297,25 +297,27 @@ def _build_model_config( final_config = _normalize_model_config_with_choices(model_config, choices) config.update_config(final_config) - else: - if model in get_args(SupportedModels): - if choices is not None: - _logger.info( - "Model string is in SupportedModels." - " Ignoring choices arguments." - ) - else: - if choices is not None: - config.update_choices(choices) - elif model in ssms_model_config: - config.update_choices(ssms_model_config[model]["choices"]) - _logger.info( - "choices argument passed as None, " - "but found %s in ssms-simulators. " - "Using choices, from ssm-simulators configs: %s", - model, - ssms_model_config[model]["choices"], - ) + # No model_config provided: apply `choices` when appropriate. + # If caller passed a SupportedModels string, ignore explicit `choices`. + if model in get_args(SupportedModels) and choices is not None: + _logger.info( + "Model string is in SupportedModels. Ignoring choices arguments." + ) + + # If model is not a supported built-in, prefer explicit choices or + # fall back to ssms-simulators lookup when available. + if model not in get_args(SupportedModels): + if choices is not None: + config.update_choices(choices) + elif model in ssms_model_config: + config.update_choices(ssms_model_config[model]["choices"]) + _logger.info( + "choices argument passed as None, " + "but found %s in ssms-simulators. " + "Using choices, from ssm-simulators configs: %s", + model, + ssms_model_config[model]["choices"], + ) config.update_loglik(loglik) config.validate() From 9f6a7ef8c7749189aea722d7de4397262729f5a8 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 14:20:14 -0400 Subject: [PATCH 27/44] Refactor _normalize_model_config_with_choices to improve input handling and choices normalization --- src/hssm/config.py | 57 +++++++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 34 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index 2b1b05c6..4b75ba3f 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -617,44 +617,33 @@ def _normalize_model_config_with_choices( and ``model_config`` already contains ``choices``, the value from ``model_config`` wins (and a log entry is emitted). """ - has_choices = ( - isinstance(model_config, dict) - and model_config.get("choices") is not None - or isinstance(model_config, ModelConfig) - and model_config.choices is not None - ) - - # Guard: no explicit choices -> return ModelConfig from input + # Normalize input to a mutable dict so we can coerce and avoid mutating + # the caller's objects. Build a fresh ModelConfig from that dict. + if isinstance(model_config, ModelConfig): + mc: dict[str, Any] = { + k: getattr(model_config, k) for k in model_config.__dataclass_fields__ + } + else: + mc = model_config.copy() + + # Coerce any existing choices on the input to a tuple for immutability + if mc.get("choices") is not None: + mc["choices"] = tuple(mc["choices"]) + + # If caller didn't provide an explicit `choices` argument, return the + # normalized ModelConfig built from the input (fresh instance). if choices is None: - return ( - model_config - if isinstance(model_config, ModelConfig) - else ModelConfig(**model_config) - ) + return ModelConfig(**{k: v for k, v in mc.items() if v is not None}) - # Guard: model_config already carries choices -> prefer it - if has_choices: + # Caller provided choices; prefer the one embedded in model_config if + # present, otherwise apply the provided value (coerced to tuple). + if mc.get("choices") is not None: _logger.info( "choices list provided in both model_config and " "as an argument directly. Using the one provided in " - "model_config.\nWe recommend providing choices in model_config." - ) - return ( - model_config - if isinstance(model_config, ModelConfig) - else ModelConfig(**model_config) + "model_config. We recommend providing choices in model_config." ) + return ModelConfig(**{k: v for k, v in mc.items() if v is not None}) - # Apply provided choices without mutating caller's object - if isinstance(model_config, dict): - mc = model_config.copy() - mc["choices"] = choices - return ModelConfig(**mc) - - mc = { - k: getattr(model_config, k) - for k in model_config.__dataclass_fields__ - if getattr(model_config, k) is not None - } - mc["choices"] = choices - return ModelConfig(**mc) + mc["choices"] = tuple(choices) + return ModelConfig(**{k: v for k, v in mc.items() if v is not None}) From 49415ab45f0b2127da768d7725f6887604cca693 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 16:16:06 -0400 Subject: [PATCH 28/44] Refactor likelihood callable construction to simplify logic and enhance clarity --- src/hssm/hssm.py | 47 ++++++++++++++++++----------------------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index f22f7c78..d7cb3141 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -334,40 +334,29 @@ def _make_model_distribution(self) -> type[pm.Distribution]: for param_name, param in self.params.items() if param_name != "p_outlier" ] - # params_is_trialwise: extends the base list with extra_fields - # (always trialwise). Used for vmap construction in - # make_likelihood_callable and for assemble_callables, where - # dist_params includes extra_fields flattened in. - params_is_trialwise = list(params_is_trialwise_base) + params_is_trialwise = params_is_trialwise_base.copy() if self.extra_fields is not None: params_is_trialwise += [True for _ in self.extra_fields] + # endregion + + # region Build the likelihood callable using guard clauses + backend = self.model_config.backend + kwargs = { + "loglik": loglik_callable, + "loglik_kind": loglik_kind, + "backend": backend, + } + if loglik_kind == "approx_differentiable" and backend == "jax": + kwargs["params_is_reg"] = params_is_trialwise # type: ignore + likelihood_callable = make_likelihood_callable(**kwargs) # type: ignore + # endregion - if self.loglik_kind == "approx_differentiable": - if self.model_config.backend == "jax": - likelihood_callable = make_likelihood_callable( - loglik=loglik_callable, - loglik_kind="approx_differentiable", - backend="jax", - params_is_reg=params_is_trialwise, - ) - else: - likelihood_callable = make_likelihood_callable( - loglik=loglik_callable, - loglik_kind="approx_differentiable", - backend=self.model_config.backend, - ) - else: - likelihood_callable = make_likelihood_callable( - loglik=loglik_callable, - loglik_kind=self.loglik_kind, - backend=self.model_config.backend, - ) - - self.loglik = likelihood_callable + # Update the authoritative `model_config` with the resolved callable + typing_cast("Config", self.model_config).update_loglik(likelihood_callable) + resolved_loglik = likelihood_callable - # Make the callable for missing data - # And assemble it with the callable for the likelihood + # Missing-data network: build and assemble the missing-data callable if self.missing_data_network != MissingDataNetwork.NONE: if self.missing_data_network == MissingDataNetwork.OPN: params_only = False From 4452f360555043be7c2d0a34478f126b67db4810 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 16:16:58 -0400 Subject: [PATCH 29/44] Refactor _make_model_distribution to utilize model_config for loglik and loglik_kind --- src/hssm/hssm.py | 39 ++++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index d7cb3141..c445c904 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -302,30 +302,27 @@ def __init__( ) def _make_model_distribution(self) -> type[pm.Distribution]: - """Make a pm.Distribution for the model.""" - ### Logic for different types of likelihoods: - # -`analytical` and `blackbox`: - # loglik should be a `pm.Distribution`` or a Python callable (any arbitrary - # function). - # - `approx_differentiable`: - # In addition to `pm.Distribution` and any arbitrary function, it can also - # be an str (which we will download from hugging face) or a Pathlike - # which we will download and make a distribution. - - # If user has already provided a log-likelihood function as a distribution - # Use it directly as the distribution - if isclass(self.loglik) and issubclass(self.loglik, pm.Distribution): - return self.loglik - - # At this point, loglik should not be a type[Distribution] and should be set - - assert self.loglik is not None, "loglik should be set" - assert self.loglik_kind is not None, "loglik_kind should be set" - assert not (isclass(self.loglik) and issubclass(self.loglik, pm.Distribution)) + """Make a pm.Distribution for the model. + + This method avoids using the deprecated proxy properties on ``self`` and + instead reads and updates the authoritative ``self.model_config``. + """ + # Read raw inputs from the typed model_config + raw_loglik = self.model_config.loglik + if isclass(raw_loglik) and issubclass( + typing_cast("type[pm.Distribution]", raw_loglik), pm.Distribution + ): + return typing_cast("type[pm.Distribution]", raw_loglik) + loglik_callable = typing_cast( - "Op | Callable[..., Any] | PathLike | str", self.loglik + "Op | Callable[..., Any] | PathLike | str", raw_loglik ) + # Prefer the typed value in model_config for loglik_kind + loglik_kind = typing_cast("LoglikKind", self.model_config.loglik_kind) + + # region Determine the trialwise nature of parameters for use in loglik and + # missing-data callables # params_is_trialwise_base: one entry per model param (excluding # p_outlier). Used for graph-level broadcasting in logp() and # make_distribution, where dist_params does not include extra_fields. From c34e56253be156139de50f0ca3383b9462465858 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 16:17:37 -0400 Subject: [PATCH 30/44] Fix formatting in HSSM class for consistency in likelihood callable parameters --- src/hssm/hssm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index c445c904..c8ccc615 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -345,8 +345,8 @@ def _make_model_distribution(self) -> type[pm.Distribution]: "backend": backend, } if loglik_kind == "approx_differentiable" and backend == "jax": - kwargs["params_is_reg"] = params_is_trialwise # type: ignore - likelihood_callable = make_likelihood_callable(**kwargs) # type: ignore + kwargs["params_is_reg"] = params_is_trialwise # type: ignore + likelihood_callable = make_likelihood_callable(**kwargs) # type: ignore # endregion # Update the authoritative `model_config` with the resolved callable From 3e86974077f43bdabda2fdfdb85f1bd48e448e66 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 16:18:28 -0400 Subject: [PATCH 31/44] Fix formatting in HSSM class for consistency in likelihood callable parameters --- src/hssm/hssm.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index c8ccc615..07e9c8ff 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -363,30 +363,32 @@ def _make_model_distribution(self) -> type[pm.Distribution]: params_only = None if self.loglik_missing_data is None: + # Use the model name from the typed config self.loglik_missing_data = ( - self.model_name + self.model_config.model_name + missing_data_networks_suffix[self.missing_data_network] + ".onnx" ) - backend_tmp: Literal["pytensor", "jax", "other"] | None = ( - "jax" - if self.model_config.backend != "pytensor" - else self.model_config.backend - ) + if self.model_config.backend != "pytensor": + backend_tmp: Literal["pytensor", "jax", "other"] | None = "jax" + else: + backend_tmp = self.model_config.backend missing_data_callable = make_missing_data_callable( self.loglik_missing_data, backend_tmp, params_is_trialwise, params_only ) self.loglik_missing_data = missing_data_callable - self.loglik = assemble_callables( - self.loglik, + assembled = assemble_callables( + resolved_loglik, self.loglik_missing_data, params_only, has_deadline=self.deadline, params_is_trialwise=params_is_trialwise, ) + typing_cast("Config", self.model_config).update_loglik(assembled) + resolved_loglik = assembled if self.missing_data: _logger.info( From 4a5aefc23369d44c3f3b8d1b0bde144f5ec127ac Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 12 Mar 2026 16:18:47 -0400 Subject: [PATCH 32/44] Refactor HSSM class to use typed model_config attributes directly and resolve loglik --- src/hssm/hssm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 07e9c8ff..b273f0e5 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -407,15 +407,14 @@ def _make_model_distribution(self) -> type[pm.Distribution]: if isinstance(param.prior, np.ndarray) } - # Use the typed `model_config` attributes directly to avoid triggering - # DeprecationWarnings from the deprecated proxy properties. + # Use the typed `model_config` attributes directly _list_params = self.model_config.list_params assert _list_params is not None, "list_params should be set" # for type checker rv_name = getattr(self.model_config, "rv") or self.model_config.model_name return make_distribution( rv=rv_name, - loglik=self.loglik, + loglik=resolved_loglik, list_params=_list_params, bounds=self.bounds, lapse=self.lapse, From cdc776329f7b85a2835d8d9761eea7b012e36b01 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 13 Mar 2026 09:41:31 -0400 Subject: [PATCH 33/44] Restore make_model_dist in HSSM --- src/hssm/hssm.py | 108 ++++++++++++++++++++++++++--------------------- 1 file changed, 60 insertions(+), 48 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index b273f0e5..01999bad 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -302,27 +302,30 @@ def __init__( ) def _make_model_distribution(self) -> type[pm.Distribution]: - """Make a pm.Distribution for the model. - - This method avoids using the deprecated proxy properties on ``self`` and - instead reads and updates the authoritative ``self.model_config``. - """ - # Read raw inputs from the typed model_config - raw_loglik = self.model_config.loglik - if isclass(raw_loglik) and issubclass( - typing_cast("type[pm.Distribution]", raw_loglik), pm.Distribution - ): - return typing_cast("type[pm.Distribution]", raw_loglik) - + """Make a pm.Distribution for the model.""" + ### Logic for different types of likelihoods: + # -`analytical` and `blackbox`: + # loglik should be a `pm.Distribution`` or a Python callable (any arbitrary + # function). + # - `approx_differentiable`: + # In addition to `pm.Distribution` and any arbitrary function, it can also + # be an str (which we will download from hugging face) or a Pathlike + # which we will download and make a distribution. + + # If user has already provided a log-likelihood function as a distribution + # Use it directly as the distribution + if isclass(self.loglik) and issubclass(self.loglik, pm.Distribution): + return self.loglik + + # At this point, loglik should not be a type[Distribution] and should be set + + assert self.loglik is not None, "loglik should be set" + assert self.loglik_kind is not None, "loglik_kind should be set" + assert not (isclass(self.loglik) and issubclass(self.loglik, pm.Distribution)) loglik_callable = typing_cast( - "Op | Callable[..., Any] | PathLike | str", raw_loglik + "Op | Callable[..., Any] | PathLike | str", self.loglik ) - # Prefer the typed value in model_config for loglik_kind - loglik_kind = typing_cast("LoglikKind", self.model_config.loglik_kind) - - # region Determine the trialwise nature of parameters for use in loglik and - # missing-data callables # params_is_trialwise_base: one entry per model param (excluding # p_outlier). Used for graph-level broadcasting in logp() and # make_distribution, where dist_params does not include extra_fields. @@ -331,29 +334,40 @@ def _make_model_distribution(self) -> type[pm.Distribution]: for param_name, param in self.params.items() if param_name != "p_outlier" ] + # params_is_trialwise: extends the base list with extra_fields - params_is_trialwise = params_is_trialwise_base.copy() + # (always trialwise). Used for vmap construction in + # make_likelihood_callable and for assemble_callables, where + # dist_params includes extra_fields flattened in. + params_is_trialwise = list(params_is_trialwise_base) if self.extra_fields is not None: params_is_trialwise += [True for _ in self.extra_fields] - # endregion - - # region Build the likelihood callable using guard clauses - backend = self.model_config.backend - kwargs = { - "loglik": loglik_callable, - "loglik_kind": loglik_kind, - "backend": backend, - } - if loglik_kind == "approx_differentiable" and backend == "jax": - kwargs["params_is_reg"] = params_is_trialwise # type: ignore - likelihood_callable = make_likelihood_callable(**kwargs) # type: ignore - # endregion - # Update the authoritative `model_config` with the resolved callable - typing_cast("Config", self.model_config).update_loglik(likelihood_callable) - resolved_loglik = likelihood_callable + if self.loglik_kind == "approx_differentiable": + if self.model_config.backend == "jax": + likelihood_callable = make_likelihood_callable( + loglik=loglik_callable, + loglik_kind="approx_differentiable", + backend="jax", + params_is_reg=params_is_trialwise, + ) + else: + likelihood_callable = make_likelihood_callable( + loglik=loglik_callable, + loglik_kind="approx_differentiable", + backend=self.model_config.backend, + ) + else: + likelihood_callable = make_likelihood_callable( + loglik=loglik_callable, + loglik_kind=self.loglik_kind, + backend=self.model_config.backend, + ) + + self.loglik = likelihood_callable - # Missing-data network: build and assemble the missing-data callable + # Make the callable for missing data + # And assemble it with the callable for the likelihood if self.missing_data_network != MissingDataNetwork.NONE: if self.missing_data_network == MissingDataNetwork.OPN: params_only = False @@ -363,32 +377,30 @@ def _make_model_distribution(self) -> type[pm.Distribution]: params_only = None if self.loglik_missing_data is None: - # Use the model name from the typed config self.loglik_missing_data = ( - self.model_config.model_name + self.model_name + missing_data_networks_suffix[self.missing_data_network] + ".onnx" ) - if self.model_config.backend != "pytensor": - backend_tmp: Literal["pytensor", "jax", "other"] | None = "jax" - else: - backend_tmp = self.model_config.backend + backend_tmp: Literal["pytensor", "jax", "other"] | None = ( + "jax" + if self.model_config.backend != "pytensor" + else self.model_config.backend + ) missing_data_callable = make_missing_data_callable( self.loglik_missing_data, backend_tmp, params_is_trialwise, params_only ) self.loglik_missing_data = missing_data_callable - assembled = assemble_callables( - resolved_loglik, + self.loglik = assemble_callables( + self.loglik, self.loglik_missing_data, params_only, has_deadline=self.deadline, params_is_trialwise=params_is_trialwise, ) - typing_cast("Config", self.model_config).update_loglik(assembled) - resolved_loglik = assembled if self.missing_data: _logger.info( @@ -410,11 +422,11 @@ def _make_model_distribution(self) -> type[pm.Distribution]: # Use the typed `model_config` attributes directly _list_params = self.model_config.list_params assert _list_params is not None, "list_params should be set" # for type checker - rv_name = getattr(self.model_config, "rv") or self.model_config.model_name + rv_name = getattr(self.model_config, "rv", None) or self.model_config.model_name return make_distribution( rv=rv_name, - loglik=resolved_loglik, + loglik=self.loglik, list_params=_list_params, bounds=self.bounds, lapse=self.lapse, From e3cbcb7f50eca6ab4b08fbb918df0aae0b48691b Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 13 Mar 2026 10:20:10 -0400 Subject: [PATCH 34/44] Remove deprecated properties and methods from HSSMBase class --- src/hssm/base.py | 135 +++++------------------------------------------ 1 file changed, 13 insertions(+), 122 deletions(-) diff --git a/src/hssm/base.py b/src/hssm/base.py index 6f66342c..765542c4 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -8,7 +8,6 @@ import datetime import logging -import warnings from abc import ABC, abstractmethod from copy import deepcopy from os import PathLike @@ -264,10 +263,19 @@ def __init__( self.model_config: BaseModelConfig = model_config # endregion - # Previously scalar shortcuts (e.g. `self.list_params`) were set here. - # These are now provided as deprecated proxy properties that forward to - # the authoritative `self.model_config` object. See property definitions - # below. + # region ===== Set up shortcuts so old code will work ====== + self.response = ( + list(self.model_config.response) + if self.model_config.response is not None + else None + ) + self.list_params = self.model_config.list_params + self.choices = self.model_config.choices # type: ignore[assignment] + self.model_name = self.model_config.model_name + self.loglik = self.model_config.loglik + self.loglik_kind = self.model_config.loglik_kind + self.extra_fields = self.model_config.extra_fields + # endregion self._validate_choices() @@ -382,123 +390,6 @@ def _make_model_distribution(self) -> type[pm.Distribution]: """ ... - def _deprecation_warn(self, name: str) -> None: - """Emit a DeprecationWarning advising to use the typed config. - - Parameters - ---------- - name - Attribute name being deprecated. - """ - warnings.warn( - f"`{name}` is deprecated; use `self.model_config.{name}` instead.", - DeprecationWarning, - stacklevel=2, - ) - - @property - def response(self): - """Deprecated proxy for `self.model_config.response`. - - Returns the list of observed response column names or None. - """ - self._deprecation_warn("response") - return ( - list(self.model_config.response) - if self.model_config.response is not None - else None - ) - - @response.setter - def response(self, value): - """Set the model_config.response value (deprecated). - - Converts the assigned value to a list or None before assignment. - """ - self._deprecation_warn("response") - self.model_config.response = list(value) if value is not None else None - - @property - def list_params(self): - """Deprecated proxy for `self.model_config.list_params`. - - Returns the parameter name list used by the likelihood. - """ - self._deprecation_warn("list_params") - return self.model_config.list_params - - @list_params.setter - def list_params(self, value): - """Set the model_config.list_params value (deprecated).""" - self._deprecation_warn("list_params") - self.model_config.list_params = value - - @property - def choices(self): - """Deprecated proxy for `self.model_config.choices`. - - Returns a tuple of valid response choices. - """ - self._deprecation_warn("choices") - return self.model_config.choices - - @choices.setter - def choices(self, value): - """Set the model_config.choices value (deprecated).""" - self._deprecation_warn("choices") - self.model_config.choices = tuple(value) if value is not None else None - - @property - def model_name(self): - """Deprecated proxy for `self.model_config.model_name`.""" - self._deprecation_warn("model_name") - return self.model_config.model_name - - @model_name.setter - def model_name(self, value): - """Set the model_config.model_name value (deprecated).""" - self._deprecation_warn("model_name") - self.model_config.model_name = value - - @property - def loglik(self): - """Deprecated proxy for `self.model_config.loglik`. - - Returns the configured log-likelihood callable or object. - """ - self._deprecation_warn("loglik") - return self.model_config.loglik - - @loglik.setter - def loglik(self, value): - """Set the model_config.loglik value (deprecated).""" - self._deprecation_warn("loglik") - self.model_config.loglik = value - - @property - def loglik_kind(self): - """Deprecated proxy for `self.model_config.loglik_kind`.""" - self._deprecation_warn("loglik_kind") - return self.model_config.loglik_kind - - @loglik_kind.setter - def loglik_kind(self, value): - """Set the model_config.loglik_kind value (deprecated).""" - self._deprecation_warn("loglik_kind") - self.model_config.loglik_kind = value - - @property - def extra_fields(self): - """Deprecated proxy for `self.model_config.extra_fields`.""" - self._deprecation_warn("extra_fields") - return self.model_config.extra_fields - - @extra_fields.setter - def extra_fields(self, value): - """Set the model_config.extra_fields value (deprecated).""" - self._deprecation_warn("extra_fields") - self.model_config.extra_fields = value - def _fix_scalar_deterministic_dims(self) -> None: """Fix dims metadata for scalar deterministics. From 7e481e050e4510735b504ed1ae5ce75a54880aca Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 13 Mar 2026 10:36:52 -0400 Subject: [PATCH 35/44] Enhance HSSMBase class to prevent overwriting _init_args if already set in subclasses and exclude additional internal names from locals() snapshots during re-instantiation. --- src/hssm/base.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/hssm/base.py b/src/hssm/base.py index 765542c4..2bd5c320 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -238,9 +238,13 @@ def __init__( # Store a safe default for the constructor-arguments snapshot so that # pickling / save-load cannot raise AttributeError if a subclass forgets - # to call `_store_init_args(locals(), kwargs)` early. Subclasses are - # still expected to overwrite this with the real snapshot. - self._init_args: dict[str, Any] = {} + # to call `_store_init_args(locals(), kwargs)` early. Subclasses are + # still expected to overwrite this with the real snapshot. However, + # do not overwrite if a subclass already set `_init_args` prior to + # calling `super().__init__()` (the subclass may capture its + # constructor args before delegating to the base class). + if not hasattr(self, "_init_args"): + self._init_args: dict[str, Any] = {} # Set up additional namespace for formula evaluation additional_namespace = transformations_namespace.copy() @@ -468,7 +472,10 @@ def _store_init_args( A mapping of parameter names to their values, suitable for reconstructing the instance via ``cls(**init_args)``. """ - result = {k: v for k, v in local_vars.items() if k not in ("self", "kwargs")} + # Exclude internal names that appear in locals() snapshots and are not + # valid constructor parameters when re-instantiating the class. + exclude_keys = {"self", "kwargs", "__class__"} + result = {k: v for k, v in local_vars.items() if k not in exclude_keys} result.update(extra_kwargs) return result From 3432bca1f1d2be40f9aaa5a235f8338c500fe4a2 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 13 Mar 2026 11:11:47 -0400 Subject: [PATCH 36/44] Clarify model_config parameter documentation in HSSMBase class to specify required fields and improve readability. --- src/hssm/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hssm/base.py b/src/hssm/base.py index 2bd5c320..81ce34da 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -115,11 +115,11 @@ class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin): model. If left unspecified, defaults will be used for all parameter specifications. Defaults to None. model_config - A fully-initialised :class:`~hssm.config.BaseModelConfig` instance - (e.g. :class:`~hssm.config.Config` or - :class:`~hssm.config.RLSSMConfig`) produced by the subclass before - calling ``super().__init__``. All likelihood, parameter, and data - information is drawn from this object. + A fully initialised :class:`~hssm.config.BaseModelConfig` instance + (typically :class:`~hssm.config.Config`) produced by the subclass + before calling ``super().__init__``. All likelihood, parameter, and + data information used by :class:`HSSMBase` is drawn from this object, + and it must provide populated ``loglik`` and ``list_params`` fields. p_outlier : optional The fixed lapse probability or the prior distribution of the lapse probability. Defaults to a fixed value of 0.05. When `None`, the lapse probability will not From 31bd6f14c81a550a47e27fcd38126f1ad3f01146 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 13 Mar 2026 11:19:08 -0400 Subject: [PATCH 37/44] Enhance HSSMBase class documentation to clarify filtering of internal names in parameter mapping for safe unpickling. --- src/hssm/base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/hssm/base.py b/src/hssm/base.py index 81ce34da..95d5c97c 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -471,6 +471,13 @@ def _store_init_args( dict[str, Any] A mapping of parameter names to their values, suitable for reconstructing the instance via ``cls(**init_args)``. + + Notes + ----- + The implementation filters out internal names that commonly appear in + ``locals()`` snapshots (for example, ``__class__`` and ``kwargs``) so + that the returned mapping is safe to pass back to the class + constructor during unpickling. """ # Exclude internal names that appear in locals() snapshots and are not # valid constructor parameters when re-instantiating the class. From 296810b0fcdc30b0915303ae8d52e2d506de6670 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 13 Mar 2026 11:19:17 -0400 Subject: [PATCH 38/44] Update model_config parameter documentation in HSSM class to support BaseModelConfig instance and clarify usage of dict for configuration. --- src/hssm/hssm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 01999bad..e86b4f7e 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -104,9 +104,12 @@ class HSSM(HSSMBase): model. If left unspecified, defaults will be used for all parameter specifications. Defaults to None. model_config : optional - A dictionary containing the model configuration information. If None is - provided, defaults will be used if there are any. Defaults to None. - Fields for this `dict` are usually: + A :class:`~hssm.config.BaseModelConfig` / :class:`~hssm.config.Config` + instance or a ``dict`` with model configuration information. The + constructor accepts a typed ``ModelConfig`` or a plain ``dict``; when a + ``dict`` is provided the library will build a typed :class:`Config` + via the factory function. If ``None`` is provided, defaults will be + used where available. Fields for this config are usually: - `"list_params"`: a list of parameters indicating the parameters of the model. The order in which the parameters are specified in this list is important. From 95779bc72fe3e12177b9dfa21a032e9430153f70 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 13 Mar 2026 12:06:46 -0400 Subject: [PATCH 39/44] Add test to validate external model config fallback in _build_model_config --- tests/test_config.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 7cf32527..4094c354 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -107,6 +107,11 @@ def test_build_model_config_modelconfig_adds_choices(self): assert cfg.choices == (0, 1) def test_build_model_config_uses_ssms_model_config(self, monkeypatch): + # High-level view of the test: ensures that when a model name is not in the built-in + # SupportedModels and no choices argument is passed, _build_model_config will consult + # the external ssms_model_config registry and use its defaults (here, the choices tuple). + # The monkeypatch fixture isolates the change and will be undone after the test. + # Simulate an external ssms_model_config entry for a model not in SupportedModels fake_model = "external_ssm" fake_choices = (2, 3) @@ -114,15 +119,23 @@ def test_build_model_config_uses_ssms_model_config(self, monkeypatch): # Monkeypatch the ssms_model_config mapping in the module import hssm.config as cfgmod + # Emulate an external package registering defaults for external_ssm. + # Ensures `_build_model_config` will consult `ssms_model_config` + # when the model name isn't in SupportedModels. monkeypatch.setitem( cfgmod.ssms_model_config, fake_model, {"choices": fake_choices} ) - # Build config with model not in SupportedModels (string) and no choices arg - # provide a loglik_kind so from_defaults does not raise for custom model - # Monkeypatch Config.validate to skip strict checks for this synthetic case - monkeypatch.setattr(Config, "validate", lambda self: None) + # Build config with model not in SupportedModels and no choices arg. + # Provide a minimal ModelConfig and a dummy `loglik` so + # `Config.validate()` runs (loglik is required) while still + # exercising the ssms-simulators choices fallback. + mc = ModelConfig(response=("rt", "response"), list_params=["v"], choices=None) result = Config._build_model_config( - fake_model, "analytical", None, choices=None + fake_model, + "analytical", + mc, + choices=None, + loglik=(lambda *a, **k: None), # required so Config.validate() passes ) assert result.choices == fake_choices From 37ea9be5dd07b19d92e2a0a1d232357df9552112 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 17 Mar 2026 10:42:48 -0400 Subject: [PATCH 40/44] Update sampling parameters in test_rlssm_sample_smoke for speed --- tests/test_rlssm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index f5ef6b03..0e60bfb2 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -309,5 +309,7 @@ def test_rlssm_pymc_model(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> def test_rlssm_sample_smoke(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: """Minimal sampling run should return an InferenceData object.""" model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) - trace = model.sample(draws=2, tune=2, chains=1, cores=1) + trace = model.sample( + draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9 + ) assert trace is not None From 9a37dd81998c495b3c01fd186bdcd4d904d96453 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 17 Mar 2026 10:43:11 -0400 Subject: [PATCH 41/44] Add RLSSM quickstart notebook for model instantiation and sampling demonstration --- docs/tutorials/rlssm_quickstart.ipynb | 357 ++++++++++++++++++++++++++ 1 file changed, 357 insertions(+) create mode 100644 docs/tutorials/rlssm_quickstart.ipynb diff --git a/docs/tutorials/rlssm_quickstart.ipynb b/docs/tutorials/rlssm_quickstart.ipynb new file mode 100644 index 00000000..343c1c58 --- /dev/null +++ b/docs/tutorials/rlssm_quickstart.ipynb @@ -0,0 +1,357 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1b9b429d", + "metadata": {}, + "source": [ + "# RLSSM Quickstart: Instantiation, Model Building, and Sampling\n", + "\n", + "This notebook provides a minimal end-to-end demonstration of the `RLSSM` class:\n", + "\n", + "1. **Load** a balanced-panel two-armed bandit dataset\n", + "2. **Define** an annotated learning function and the angle SSM log-likelihood\n", + "3. **Configure** and **instantiate** an `RLSSM` model\n", + "4. **Inspect** the built Bambi / PyMC model\n", + "5. **Run** a minimal 2-draw sampling smoke test\n", + "\n", + "For a full treatment — simulating data, hierarchical formulas, meaningful sampling, and posterior visualization — see:\n", + "- [rlssm_tutorial.ipynb](rlssm_tutorial.ipynb)\n", + "- [add_custom_rlssm_model.ipynb](add_custom_rlssm_model.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "bf38d7f7", + "metadata": {}, + "source": [ + "## 1. Imports and Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d764731", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import hssm\n", + "from hssm import RLSSM, RLSSMConfig\n", + "from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx\n", + "from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise\n", + "from hssm.utils import annotate_function\n", + "\n", + "# RLSSM requires float32 throughout (JAX default).\n", + "hssm.set_floatX(\"float32\", update_jax=True)" + ] + }, + { + "cell_type": "markdown", + "id": "df12303f", + "metadata": {}, + "source": [ + "## 2. Load the Dataset\n", + "\n", + "We use a small synthetic two-armed bandit dataset from the HSSM test fixtures. \n", + "It is a **balanced panel**: every participant has the same number of trials. \n", + "Columns: `participant_id`, `trial_id`, `rt`, `response`, `feedback`.\n", + "\n", + "> **Note:** You can also generate data with\n", + "> [`ssm-simulators`](https://github.com/AlexanderFengler/ssm-simulators).\n", + "> See `rlssm_tutorial.ipynb` for an example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2ef5f6e", + "metadata": {}, + "outputs": [], + "source": [ + "# Path relative to docs/tutorials/ when running inside the HSSM repo.\n", + "_fixture_path = Path(\"../../tests/fixtures/rldm_data.npy\")\n", + "raw = np.load(_fixture_path, allow_pickle=True).item()\n", + "data = pd.DataFrame(raw[\"data\"])\n", + "\n", + "n_participants = data[\"participant_id\"].nunique()\n", + "n_trials = len(data) // n_participants\n", + "\n", + "print(data.head())\n", + "print(f\"\\nParticipants: {n_participants} | Trials per participant: {n_trials}\")" + ] + }, + { + "cell_type": "markdown", + "id": "8c310290", + "metadata": {}, + "source": [ + "## 3. Define the Learning Process\n", + "\n", + "The RL learning process is a JAX function that, given a subject's trial sequence, computes\n", + "the trial-wise drift rate `v` via a Q-learning update rule. \n", + "\n", + "`annotate_function` attaches `.inputs`, `.outputs`, and (optionally) `.computed` metadata\n", + "that the RLSSM likelihood builder uses to automatically construct the input matrix for the\n", + "decision process.\n", + "\n", + "- **inputs** — columns that the function reads (free parameters + data columns)\n", + "- **outputs** — what the function produces (here: `v`, the drift rate)\n", + "\n", + "Here we annotate the built-in `compute_v_subject_wise` function, which implements a simple\n", + "Rescorla-Wagner Q-learning update for a two-armed bandit task." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbcea122", + "metadata": {}, + "outputs": [], + "source": [ + "compute_v_annotated = annotate_function(\n", + " inputs=[\"rl_alpha\", \"scaler\", \"response\", \"feedback\"],\n", + " outputs=[\"v\"],\n", + ")(compute_v_subject_wise)\n", + "\n", + "print(\"Learning function inputs :\", compute_v_annotated.inputs)\n", + "print(\"Learning function outputs:\", compute_v_annotated.outputs)" + ] + }, + { + "cell_type": "markdown", + "id": "7a03305a", + "metadata": {}, + "source": [ + "## 4. Define the Decision (SSM) Log-Likelihood\n", + "\n", + "The decision process uses the **angle model** likelihood, loaded from an ONNX file.\n", + "`make_jax_matrix_logp_funcs_from_onnx` returns a JAX callable that accepts a\n", + "2-D matrix whose columns are `[v, a, z, t, theta, rt, response]` and returns\n", + "per-trial log-probabilities.\n", + "\n", + "We then annotate that callable so the builder knows:\n", + "- which columns the matrix contains (`inputs`)\n", + "- that `v` itself is *computed* by the learning function (not a free parameter)\n", + "\n", + "The ONNX file is loaded from the local test fixture when running inside the HSSM\n", + "repository; otherwise it is downloaded from the HuggingFace Hub (`franklab/HSSM`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60bbc036", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the local fixture when available; fall back to HuggingFace download.\n", + "_local_onnx = Path(\"../../tests/fixtures/angle.onnx\").resolve()\n", + "_onnx_model = str(_local_onnx) if _local_onnx.exists() else \"angle.onnx\"\n", + "\n", + "_angle_logp_jax = make_jax_matrix_logp_funcs_from_onnx(model=_onnx_model)\n", + "\n", + "angle_logp_func = annotate_function(\n", + " inputs=[\"v\", \"a\", \"z\", \"t\", \"theta\", \"rt\", \"response\"],\n", + " outputs=[\"logp\"],\n", + " computed={\"v\": compute_v_annotated},\n", + ")(_angle_logp_jax)\n", + "\n", + "print(\"SSM logp inputs :\", angle_logp_func.inputs)\n", + "print(\"SSM logp outputs:\", angle_logp_func.outputs)\n", + "print(\"Computed deps :\", list(angle_logp_func.computed.keys()))" + ] + }, + { + "cell_type": "markdown", + "id": "cf8f5b63", + "metadata": {}, + "source": [ + "## 5. Configure the Model with `RLSSMConfig`\n", + "\n", + "`RLSSMConfig` collects all the information the RLSSM class needs:\n", + "\n", + "| Field | Purpose |\n", + "|-------|---------|\n", + "| `model_name` | Identifier string for the configuration |\n", + "| `decision_process` | Name of the SSM (e.g. `\"angle\"`) |\n", + "| `list_params` | Ordered list of *free* parameters to sample |\n", + "| `params_default` | Starting / default values for each parameter |\n", + "| `bounds` | Prior bounds for each parameter |\n", + "| `learning_process` | Dict mapping computed param name → annotated learning function |\n", + "| `extra_fields` | Extra data columns required by the learning function |\n", + "| `ssm_logp_func` | Annotated JAX callable for the decision-process likelihood |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4beba1bc", + "metadata": {}, + "outputs": [], + "source": [ + "rlssm_config = RLSSMConfig(\n", + " model_name=\"rlssm_angle_quickstart\",\n", + " loglik_kind=\"approx_differentiable\",\n", + " decision_process=\"angle\",\n", + " decision_process_loglik_kind=\"approx_differentiable\",\n", + " learning_process_loglik_kind=\"blackbox\",\n", + " list_params=[\"rl_alpha\", \"scaler\", \"a\", \"theta\", \"t\", \"z\"],\n", + " params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5],\n", + " bounds={\n", + " \"rl_alpha\": (0.0, 1.0),\n", + " \"scaler\": (0.0, 10.0),\n", + " \"a\": (0.1, 3.0),\n", + " \"theta\": (-0.1, 0.1),\n", + " \"t\": (0.001, 1.0),\n", + " \"z\": (0.1, 0.9),\n", + " },\n", + " learning_process={\"v\": compute_v_annotated},\n", + " response=[\"rt\", \"response\"],\n", + " choices=[0, 1],\n", + " extra_fields=[\"feedback\"],\n", + " ssm_logp_func=angle_logp_func,\n", + ")\n", + "\n", + "print(\"Model name :\", rlssm_config.model_name)\n", + "print(\"Free params :\", rlssm_config.list_params)" + ] + }, + { + "cell_type": "markdown", + "id": "924ee4c7", + "metadata": {}, + "source": [ + "## 6. Instantiate the `RLSSM` Model\n", + "\n", + "Passing `data` and `rlssm_config` to `RLSSM`:\n", + "\n", + "- validates the balanced-panel requirement\n", + "- builds a differentiable PyTensor Op that chains the RL learning step and the\n", + " angle log-likelihood\n", + "- constructs the Bambi / PyMC model internally\n", + "\n", + "Note that `v` (the drift rate) is *not* a free parameter — it is computed inside\n", + "the Op by the Q-learning update and therefore does not appear in `model.params`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f8da79a", + "metadata": {}, + "outputs": [], + "source": [ + "model = RLSSM(data=data, rlssm_config=rlssm_config)\n", + "\n", + "assert isinstance(model, RLSSM)\n", + "print(\"Model type :\", type(model).__name__)\n", + "print(\"Participants :\", model.n_participants)\n", + "print(\"Trials/subj :\", model.n_trials)\n", + "print(\"Free parameters :\", list(model.params.keys()))\n", + "assert \"rl_alpha\" in model.params, \"rl_alpha must be a free parameter\"\n", + "assert \"v\" not in model.params, \"v is computed, not a free parameter\"\n", + "model" + ] + }, + { + "cell_type": "markdown", + "id": "f7f39940", + "metadata": {}, + "source": [ + "## 7. Inspect the Built Model\n", + "\n", + "After construction, `model.model` exposes the underlying **Bambi model** and\n", + "`model.pymc_model` exposes the **PyMC model** context — useful for debugging\n", + "or customizing priors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0558ad4", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=== Bambi model ===\")\n", + "print(model.model)\n", + "\n", + "print(\"\\n=== PyMC model ===\")\n", + "print(model.pymc_model)" + ] + }, + { + "cell_type": "markdown", + "id": "f4e50110", + "metadata": {}, + "source": [ + "## 8. Sampling\n", + "\n", + "A minimal sampling run — 2 draws, 2 tuning steps, 1 chain — confirms that the full\n", + "computational graph (Q-learning scan → angle logp → NUTS gradient) is wired correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96ce3238", + "metadata": {}, + "outputs": [], + "source": [ + "trace = model.sample(draws=2, tune=2, chains=1, cores=1, sampler=\"numpyro\", target_accept=0.9)\n", + "\n", + "assert trace is not None\n", + "print(trace)" + ] + }, + { + "cell_type": "markdown", + "id": "a784a468", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook showed how to:\n", + "\n", + "1. Load a balanced-panel dataset (`rldm_data.npy`)\n", + "2. Annotate a Q-learning function with `annotate_function`\n", + "3. Load the angle ONNX likelihood and annotate it so the builder can assemble the input matrix\n", + "4. Define an `RLSSMConfig` and pass it to `RLSSM`\n", + "5. Confirm model structure (free params, Bambi / PyMC objects)\n", + "6. Run a 2-draw sampling smoke test that returns an `arviz.InferenceData` object\n", + "\n", + "**Next steps:**\n", + "\n", + "- For a full tutorial (simulate data, hierarchical formulas, `1 000`-draw posterior\n", + " sampling, `arviz` plots) see [`rlssm_tutorial.ipynb`](rlssm_tutorial.ipynb).\n", + "- To add a custom RL model see [`add_custom_rlssm_model.ipynb`](add_custom_rlssm_model.ipynb)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hssm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 43ec65251a2aa5e75656b17fb70349ac70c313a6 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 17 Mar 2026 10:43:29 -0400 Subject: [PATCH 42/44] Add RLSSM Quickstart tutorial to navigation and plugins --- mkdocs.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mkdocs.yml b/mkdocs.yml index 66ab3e10..286ef08c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -43,6 +43,7 @@ nav: - Hierarchical Variational Inference: tutorials/variational_inference_hierarchical.ipynb - Using HSSM low-level API directly with PyMC: tutorials/pymc.ipynb - Reinforcement Learning - Sequential Sampling Models (RLSSM): tutorials/rlssm_tutorial.ipynb + - RLSSM Quickstart: tutorials/rlssm_quickstart.ipynb - Add custom RLSSM models: tutorials/add_custom_rlssm_model.ipynb - Custom models: tutorials/jax_callable_contribution_onnx_example.ipynb - Custom models from onnx files: tutorials/blackbox_contribution_onnx_example.ipynb @@ -91,6 +92,7 @@ plugins: - tutorials/hssm_tutorial_workshop_2.ipynb - tutorials/add_custom_rlssm_model.ipynb - tutorials/rlssm_tutorial.ipynb + - tutorials/rlssm_quickstart.ipynb - tutorials/lapse_prob_and_dist.ipynb - tutorials/plotting.ipynb - tutorials/scientific_workflow_hssm.ipynb From 7f1e6ff44373a91f94633ff68e282b52f36af64c Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 17 Mar 2026 12:11:17 -0400 Subject: [PATCH 43/44] Remove redundant next steps and streamline summary in RLSSM quickstart notebook --- docs/tutorials/rlssm_quickstart.ipynb | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/docs/tutorials/rlssm_quickstart.ipynb b/docs/tutorials/rlssm_quickstart.ipynb index 343c1c58..e4ddc763 100644 --- a/docs/tutorials/rlssm_quickstart.ipynb +++ b/docs/tutorials/rlssm_quickstart.ipynb @@ -323,13 +323,7 @@ "3. Load the angle ONNX likelihood and annotate it so the builder can assemble the input matrix\n", "4. Define an `RLSSMConfig` and pass it to `RLSSM`\n", "5. Confirm model structure (free params, Bambi / PyMC objects)\n", - "6. Run a 2-draw sampling smoke test that returns an `arviz.InferenceData` object\n", - "\n", - "**Next steps:**\n", - "\n", - "- For a full tutorial (simulate data, hierarchical formulas, `1 000`-draw posterior\n", - " sampling, `arviz` plots) see [`rlssm_tutorial.ipynb`](rlssm_tutorial.ipynb).\n", - "- To add a custom RL model see [`add_custom_rlssm_model.ipynb`](add_custom_rlssm_model.ipynb)." + "6. Run a 2-draw sampling smoke test that returns an `arviz.InferenceData` object" ] } ], From e604406575fa4d96889a7604e5675a8895cb66b7 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 15:49:11 -0400 Subject: [PATCH 44/44] Refactor RLSSM class to use model_config instead of rlssm_config for consistency --- src/hssm/rl/rlssm.py | 24 ++++++------ tests/test_rlssm.py | 80 +++++++++++++++----------------------- tests/test_rlssm_config.py | 22 +++++------ 3 files changed, 55 insertions(+), 71 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index e268a16f..c76a0a81 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -47,7 +47,7 @@ class RLSSM(HSSMBase): The likelihood is built via :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op` from the annotated - SSM function stored in *rlssm_config.ssm_logp_func*. This produces a + SSM function stored in *model_config.ssm_logp_func*. This produces a differentiable pytensor ``Op`` that is passed directly to :func:`~hssm.distribution_utils.make_distribution`, superseding the ``loglik`` / ``loglik_kind`` dispatching used by :class:`~hssm.hssm.HSSM`. @@ -56,12 +56,12 @@ class RLSSM(HSSMBase): ---------- data : pd.DataFrame Trial-level data. Must contain at least the response columns - specified in *rlssm_config* (typically ``"rt"`` and ``"response"``), + specified in *model_config* (typically ``"rt"`` and ``"response"``), a participant identifier column (default ``"participant_id"``), and - any extra fields listed in *rlssm_config.extra_fields*. + any extra fields listed in *model_config.extra_fields*. The data **must** form a balanced panel: every participant must have the same number of trials. - rlssm_config : RLSSMConfig + model_config : RLSSMConfig Full configuration for the RLSSM model. Must have ``ssm_logp_func`` set to the annotated JAX SSM log-likelihood function. participant_col : str, optional @@ -108,7 +108,7 @@ class RLSSM(HSSMBase): def __init__( self, data: pd.DataFrame, - rlssm_config: RLSSMConfig, + model_config: RLSSMConfig, participant_col: str = "participant_id", include: list[dict[str, Any] | Any] | None = None, p_outlier: float | dict | bmb.Prior | None = 0.05, @@ -127,7 +127,7 @@ def __init__( self._init_args = self._store_init_args(locals(), kwargs) # Validate config (ensures ssm_logp_func is present, etc.) - rlssm_config.validate() + model_config.validate() # RLSSM reshapes rows into (n_participants, n_trials, ...) by position, # so _rearrange_data (which moves missing/deadline rows to the front) @@ -156,7 +156,7 @@ def __init__( # Store RL-specific state on self BEFORE super().__init__() so that # _make_model_distribution() (called from super) can access them. - self.config = rlssm_config + self.config = model_config self.n_participants = n_participants self.n_trials = n_trials @@ -170,18 +170,18 @@ def __init__( # "p_outlier" to self.list_params, and that mutation must NOT be visible # to the Op's _validate_args_length check at sampling time. loglik_op = make_rl_logp_op( - ssm_logp_func=rlssm_config.ssm_logp_func, + ssm_logp_func=model_config.ssm_logp_func, n_participants=n_participants, n_trials=n_trials, - data_cols=list(rlssm_config.response), # type: ignore[arg-type] - list_params=list(rlssm_config.list_params), # type: ignore[arg-type] - extra_fields=list(rlssm_config.extra_fields or []), + data_cols=list(model_config.response), # type: ignore[arg-type] + list_params=list(model_config.list_params), # type: ignore[arg-type] + extra_fields=list(model_config.extra_fields or []), ) # Build a typed Config instance via RLSSMConfig's own factory method. # The differentiable Op is passed so Config.validate() is satisfied; # loglik_kind="approx_differentiable" reflects that the Op has gradients. - config = rlssm_config._build_model_config(loglik_op) + config = model_config._build_model_config(loglik_op) super().__init__( data=data, diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index 0e60bfb2..98b34181 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -104,16 +104,16 @@ def rlssm_config() -> RLSSMConfig: # --------------------------------------------------------------------------- -def test_rlssm_init(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_init(rldm_data, rlssm_config) -> None: """Basic RLSSM initialisation should succeed and return an RLSSM instance.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) assert isinstance(model, RLSSM) assert model.model_config.model_name == "rldm_test" -def test_rlssm_panel_attrs(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_panel_attrs(rldm_data, rlssm_config) -> None: """n_participants and n_trials should match the fixture data structure.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) n_participants = rldm_data["participant_id"].nunique() n_trials = len(rldm_data) // n_participants @@ -122,35 +122,29 @@ def test_rlssm_panel_attrs(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) - assert model.n_trials == n_trials -def test_rlssm_params_keys(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_params_keys(rldm_data, rlssm_config) -> None: """model.params should contain exactly list_params + p_outlier.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) expected = set(rlssm_config.list_params) | {"p_outlier"} assert set(model.params.keys()) == expected -def test_rlssm_unbalanced_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_unbalanced_raises(rldm_data, rlssm_config) -> None: """Dropping one row should make the panel unbalanced → ValueError.""" unbalanced = rldm_data.iloc[:-1].copy() with pytest.raises(ValueError, match="balanced panels"): - RLSSM(data=unbalanced, rlssm_config=rlssm_config) + RLSSM(data=unbalanced, model_config=rlssm_config) -def test_rlssm_nan_participant_id_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_nan_participant_id_raises(rldm_data, rlssm_config) -> None: """NaN in participant_id column should raise ValueError before groupby silently drops rows.""" nan_data = rldm_data.copy() nan_data.loc[nan_data.index[0], "participant_id"] = float("nan") with pytest.raises(ValueError, match="NaN"): - RLSSM(data=nan_data, rlssm_config=rlssm_config) + RLSSM(data=nan_data, model_config=rlssm_config) -def test_rlssm_missing_ssm_logp_func_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_missing_ssm_logp_func_raises(rldm_data, rlssm_config) -> None: """RLSSMConfig without ssm_logp_func should raise ValueError on init.""" bad_config = RLSSMConfig( model_name="rldm_bad", @@ -168,12 +162,10 @@ def test_rlssm_missing_ssm_logp_func_raises( # ssm_logp_func intentionally omitted → defaults to None ) with pytest.raises(ValueError, match="ssm_logp_func"): - RLSSM(data=rldm_data, rlssm_config=bad_config) + RLSSM(data=rldm_data, model_config=bad_config) -def test_rlssm_unannotated_ssm_logp_func_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_unannotated_ssm_logp_func_raises(rldm_data, rlssm_config) -> None: """A plain callable without @annotate_function attrs should raise ValueError.""" bad_config = RLSSMConfig( model_name="rldm_bad", @@ -191,23 +183,19 @@ def test_rlssm_unannotated_ssm_logp_func_raises( ssm_logp_func=lambda x: x, # callable but no .inputs/.outputs/.computed ) with pytest.raises(ValueError, match="annotate_function"): - RLSSM(data=rldm_data, rlssm_config=bad_config) + RLSSM(data=rldm_data, model_config=bad_config) -def test_rlssm_missing_data_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_missing_data_raises(rldm_data, rlssm_config) -> None: """Passing missing_data!=False should raise ValueError with 'missing_data' in msg.""" with pytest.raises(ValueError, match="missing_data"): - RLSSM(data=rldm_data, rlssm_config=rlssm_config, missing_data=True) + RLSSM(data=rldm_data, model_config=rlssm_config, missing_data=True) -def test_rlssm_deadline_raises( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_deadline_raises(rldm_data, rlssm_config) -> None: """Passing deadline!=False should raise ValueError with 'deadline' in msg.""" with pytest.raises(ValueError, match="deadline"): - RLSSM(data=rldm_data, rlssm_config=rlssm_config, deadline=True) + RLSSM(data=rldm_data, model_config=rlssm_config, deadline=True) # --------------------------------------------------------------------------- @@ -215,11 +203,9 @@ def test_rlssm_deadline_raises( # --------------------------------------------------------------------------- -def test_rlssm_params_is_trialwise_aligned( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_params_is_trialwise_aligned(rldm_data, rlssm_config) -> None: """params_is_trialwise must align with list_params (same length, p_outlier=False).""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) assert model.model_config.list_params is not None params_is_trialwise = [ name != "p_outlier" for name in model.model_config.list_params @@ -232,28 +218,28 @@ def test_rlssm_params_is_trialwise_aligned( assert is_tw, f"{name} must be trialwise" -def test_rlssm_get_prefix(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_get_prefix(rldm_data, rlssm_config) -> None: """_get_prefix must use token-based matching, not substring search. - 'rl_alpha_Intercept' → 'rl_alpha' (underscore-containing RL param) - 'p_outlier_log__' → 'p_outlier' (lapse param via token loop, not substring) - 'a_Intercept' → 'a' (single-token standard param) """ - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) assert model._get_prefix("rl_alpha_Intercept") == "rl_alpha" assert model._get_prefix("p_outlier_log__") == "p_outlier" assert model._get_prefix("a_Intercept") == "a" -def test_rlssm_no_lapse(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_no_lapse(rldm_data, rlssm_config) -> None: """Setting p_outlier=None should remove p_outlier from params.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config, p_outlier=None) + model = RLSSM(data=rldm_data, model_config=rlssm_config, p_outlier=None) assert "p_outlier" not in model.params -def test_rlssm_model_built(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_model_built(rldm_data, rlssm_config) -> None: """The bambi model should be built and the computed param 'v' absent from params.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) assert model.model is not None # rl_alpha is a free (sampled) parameter assert "rl_alpha" in model.params @@ -261,9 +247,7 @@ def test_rlssm_model_built(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) - assert "v" not in model.params -def test_rlssm_extra_fields_are_copies( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: +def test_rlssm_extra_fields_are_copies(rldm_data, rlssm_config) -> None: """extra_fields passed to make_distribution must be independent numpy copies. to_numpy(copy=True) should return a new buffer; if it returned a view, @@ -273,7 +257,7 @@ def test_rlssm_extra_fields_are_copies( from hssm.distribution_utils import make_distribution as real_make_distribution - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) captured: dict = {} def capturing_make_distribution(*args, **kwargs): @@ -294,9 +278,9 @@ def capturing_make_distribution(*args, **kwargs): ) -def test_rlssm_pymc_model(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_pymc_model(rldm_data, rlssm_config) -> None: """pymc_model should be accessible after model construction.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) assert model.pymc_model is not None @@ -306,9 +290,9 @@ def test_rlssm_pymc_model(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> @pytest.mark.slow -def test_rlssm_sample_smoke(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) -> None: +def test_rlssm_sample_smoke(rldm_data, rlssm_config) -> None: """Minimal sampling run should return an InferenceData object.""" - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) trace = model.sample( draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9 ) diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 143ff117..46859da4 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -313,7 +313,7 @@ def test_to_config_cases( expected_default_priors, raises, ): - rlssm_config = RLSSMConfig( + model_config = RLSSMConfig( model_name="test_model", list_params=list_params, params_default=params_default, @@ -327,14 +327,14 @@ def test_to_config_cases( ) if raises: with pytest.raises(raises): - rlssm_config.to_config() + model_config.to_config() else: - config = rlssm_config.to_config() + config = model_config.to_config() assert config.backend == expected_backend assert config.default_priors == expected_default_priors def test_to_config(self): - rlssm_config = RLSSMConfig( + model_config = RLSSMConfig( model_name="rlwm", description="RLWM model", list_params=["alpha", "beta", "v", "a"], @@ -354,7 +354,7 @@ def test_to_config(self): learning_process_loglik_kind="blackbox", learning_process={}, ) - config = rlssm_config.to_config() + config = model_config.to_config() assert isinstance(config, Config) assert config.model_name == "rlwm" assert config.description == "RLWM model" @@ -378,7 +378,7 @@ def test_to_config(self): } def test_to_config_defaults_backend(self): - rlssm_config = RLSSMConfig( + model_config = RLSSMConfig( model_name="test_model", list_params=["alpha"], params_default=[0.5], @@ -389,11 +389,11 @@ def test_to_config_defaults_backend(self): learning_process_loglik_kind="blackbox", learning_process={}, ) - config = rlssm_config.to_config() + config = model_config.to_config() assert config.backend == "jax" def test_to_config_no_defaults(self): - rlssm_config = RLSSMConfig( + model_config = RLSSMConfig( model_name="test_model", list_params=["alpha", "beta"], params_default=[], @@ -404,11 +404,11 @@ def test_to_config_no_defaults(self): learning_process_loglik_kind="blackbox", learning_process={}, ) - config = rlssm_config.to_config() + config = model_config.to_config() assert config.default_priors == {} def test_to_config_mismatched_defaults_length(self): - rlssm_config = RLSSMConfig( + model_config = RLSSMConfig( model_name="test_model", list_params=["alpha", "beta", "gamma"], params_default=[0.5, 0.3], @@ -423,7 +423,7 @@ def test_to_config_mismatched_defaults_length(self): ValueError, match=r"params_default length \(2\) doesn't match list_params length \(3\)", ): - rlssm_config.to_config() + model_config.to_config() class TestRLSSMConfigLearningProcess: