From aec153133291f65939fb01beeebcd89eb74fdaba Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 13:07:10 -0400 Subject: [PATCH 01/22] Refactor RLSSMConfig methods to simplify parameter handling and remove unused conversion tests --- src/hssm/config.py | 120 ++------------------------------ tests/test_rlssm_config.py | 139 ------------------------------------- 2 files changed, 5 insertions(+), 254 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index 4b75ba3f..a2893660 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -474,122 +474,12 @@ def get_defaults( Returns ------- tuple - A tuple of (default_value, bounds) where: - - default_value is a float or None if not found - - bounds is a tuple (lower, upper) or None if not found + ``(None, bounds)`` — the default prior is intentionally ``None`` so that + ``prior_settings="safe"`` in :class:`~hssm.base.HSSMBase` assigns + priors from bounds rather than fixing parameters to scalar constants. + ``params_default`` stores initialisation values, not prior distributions. """ - # Try to find the parameter in list_params and get its default value - default_val = None - if self.list_params is not None: - try: - param_idx = self.list_params.index(param) - if self.params_default and param_idx < len(self.params_default): - default_val = self.params_default[param_idx] - except ValueError: - # Parameter not in list_params - pass - - return default_val, self.bounds.get(param) - - def to_config(self) -> Config: - """Convert to standard Config for compatibility with HSSM. - - This method transforms the RLSSM configuration into a standard Config - object that can be used with the existing HSSM infrastructure. - - Returns - ------- - Config - A Config object with RLSSM parameters mapped to standard format. - - Notes - ----- - The transformation converts params_default list to default_priors dict, - mapping parameter names to their default values. - """ - # Validate parameter defaults consistency before conversion - if self.params_default and self.list_params: - if len(self.params_default) != len(self.list_params): - raise ValueError( - f"params_default length ({len(self.params_default)}) doesn't " - f"match list_params length ({len(self.list_params)}). " - "This would result in silent data loss during conversion." - ) - - # Transform params_default list to default_priors dict - default_priors = ( - { - param: default - for param, default in zip(self.list_params, self.params_default) - } - if self.list_params and self.params_default - else {} - ) - - return Config( - model_name=self.model_name, - loglik_kind=self.loglik_kind, - response=self.response, - choices=self.choices, - list_params=self.list_params, - description=self.description, - bounds=self.bounds, - default_priors=cast( - "dict[str, float | dict[str, Any] | Any | None]", default_priors - ), - extra_fields=self.extra_fields, - backend=self.backend or "jax", # RLSSM typically uses JAX - loglik=self.loglik, - ) - - def to_model_config(self) -> ModelConfig: - """Build a :class:`ModelConfig` from this :class:`RLSSMConfig`. - - All fields are sourced from ``self``; the backend is fixed to ``"jax"`` - because RLSSM exclusively uses the JAX backend. - - ``default_priors`` is intentionally left empty so the - ``prior_settings="safe"`` mechanism in :class:`~hssm.base.HSSMBase` - assigns sensible priors from bounds rather than fixing every parameter - to a constant scalar. - """ - return ModelConfig( - response=tuple(self.response), # type: ignore[arg-type] - list_params=list(self.list_params), # type: ignore[arg-type] - choices=tuple(self.choices), # type: ignore[arg-type] - default_priors={}, - bounds=self.bounds, - extra_fields=self.extra_fields, - backend="jax", - ) - - def _build_model_config(self, loglik_op: Any) -> Config: - """Build a validated :class:`Config` for use by :class:`~hssm.rl.rlssm.RLSSM`. - - Converts this :class:`RLSSMConfig` to a :class:`ModelConfig`, then - delegates to :meth:`Config._build_model_config` using the pre-built - differentiable Op as ``loglik``. - - Parameters - ---------- - loglik_op - The differentiable pytensor Op produced by - :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`. - - Returns - ------- - Config - A fully validated :class:`Config` ready to pass to - :meth:`~hssm.base.HSSMBase.__init__`. - """ - mc = self.to_model_config() - return Config._build_model_config( - self.model_name, - "approx_differentiable", - mc, - None, - loglik_op, - ) + return None, self.bounds.get(param) @dataclass diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 143ff117..6f0bcbb3 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -287,145 +287,6 @@ def test_get_defaults_cases( assert bounds_val == expected_bounds -class TestRLSSMConfigConversion: - @pytest.mark.parametrize( - "list_params, params_default, backend, expected_backend, expected_default_priors, raises", - [ - ( - ["alpha", "beta", "v", "a"], - [0.5, 0.3, 1.0, 1.5], - "jax", - "jax", - {"alpha": 0.5, "beta": 0.3, "v": 1.0, "a": 1.5}, - None, - ), - (["alpha"], [0.5], None, "jax", {"alpha": 0.5}, None), - (["alpha", "beta"], [], None, "jax", {}, None), - (["alpha", "beta", "gamma"], [0.5, 0.3], None, None, None, ValueError), - ], - ) - def test_to_config_cases( - self, - list_params, - params_default, - backend, - expected_backend, - expected_default_priors, - raises, - ): - rlssm_config = RLSSMConfig( - model_name="test_model", - list_params=list_params, - params_default=params_default, - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - backend=backend, - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - if raises: - with pytest.raises(raises): - rlssm_config.to_config() - else: - config = rlssm_config.to_config() - assert config.backend == expected_backend - assert config.default_priors == expected_default_priors - - def test_to_config(self): - rlssm_config = RLSSMConfig( - model_name="rlwm", - description="RLWM model", - list_params=["alpha", "beta", "v", "a"], - params_default=[0.5, 0.3, 1.0, 1.5], - bounds={ - "alpha": (0.0, 1.0), - "beta": (0.0, 1.0), - "v": (-3.0, 3.0), - "a": (0.3, 2.5), - }, - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - extra_fields=["feedback"], - backend="jax", - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = rlssm_config.to_config() - assert isinstance(config, Config) - assert config.model_name == "rlwm" - assert config.description == "RLWM model" - assert config.list_params == ["alpha", "beta", "v", "a"] - assert config.response == ["rt", "response"] - assert config.choices == [0, 1] - assert config.extra_fields == ["feedback"] - assert config.backend == "jax" - assert config.loglik_kind == "approx_differentiable" - assert config.bounds == { - "alpha": (0.0, 1.0), - "beta": (0.0, 1.0), - "v": (-3.0, 3.0), - "a": (0.3, 2.5), - } - assert config.default_priors == { - "alpha": 0.5, - "beta": 0.3, - "v": 1.0, - "a": 1.5, - } - - def test_to_config_defaults_backend(self): - rlssm_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha"], - params_default=[0.5], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = rlssm_config.to_config() - assert config.backend == "jax" - - def test_to_config_no_defaults(self): - rlssm_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta"], - params_default=[], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = rlssm_config.to_config() - assert config.default_priors == {} - - def test_to_config_mismatched_defaults_length(self): - rlssm_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta", "gamma"], - params_default=[0.5, 0.3], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - with pytest.raises( - ValueError, - match=r"params_default length \(2\) doesn't match list_params length \(3\)", - ): - rlssm_config.to_config() - - class TestRLSSMConfigLearningProcess: def test_learning_process(self): config = RLSSMConfig( From a8cd51dfb109280383fd2b3d64f132c3c3e9a80f Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 13:08:01 -0400 Subject: [PATCH 02/22] Fix handling of list_params in HSSMBase to ensure proper conversion from None --- src/hssm/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/hssm/base.py b/src/hssm/base.py index 95d5c97c..02c8c1da 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -273,7 +273,11 @@ def __init__( if self.model_config.response is not None else None ) - self.list_params = self.model_config.list_params + self.list_params = ( + list(self.model_config.list_params) + if self.model_config.list_params is not None + else None + ) self.choices = self.model_config.choices # type: ignore[assignment] self.model_name = self.model_config.model_name self.loglik = self.model_config.loglik From 9c22e26df70b8cffbb06f058c8c322120f548e68 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 13:08:50 -0400 Subject: [PATCH 03/22] Refactor RLSSM to inject model configuration directly, removing unnecessary Config conversion --- src/hssm/rl/rlssm.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index e268a16f..9d3c6a75 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -26,7 +26,7 @@ from pytensor.graph import Op -from hssm.config import Config, RLSSMConfig +from hssm.config import RLSSMConfig from hssm.defaults import ( INITVAL_JITTER_SETTINGS, ) @@ -97,8 +97,9 @@ class RLSSM(HSSMBase): Attributes ---------- - config : RLSSMConfig - The RLSSM configuration object. + model_config : RLSSMConfig + The RLSSM configuration object (stored as ``self.model_config`` on + :class:`~hssm.base.HSSMBase` with the built ``loglik`` Op injected). n_participants : int Number of participants inferred from *data*. n_trials : int @@ -156,14 +157,13 @@ def __init__( # Store RL-specific state on self BEFORE super().__init__() so that # _make_model_distribution() (called from super) can access them. - self.config = rlssm_config self.n_participants = n_participants self.n_trials = n_trials # Build the differentiable pytensor Op from the annotated SSM function. - # This Op supersedes the loglik/loglik_kind workflow: it is passed as - # `loglik` to HSSMBase so Config.validate() is satisfied, and - # _make_model_distribution() uses it directly without any further wrapping. + # This Op supersedes the loglik/loglik_kind workflow: it is stored on + # rlssm_config.loglik so that HSSMBase can access it uniformly via + # self.model_config.loglik, without any Config conversion. # # Fresh list() copies are passed to make_rl_logp_op so the closure inside # captures its own isolated list objects. HSSMBase will later append @@ -178,14 +178,14 @@ def __init__( extra_fields=list(rlssm_config.extra_fields or []), ) - # Build a typed Config instance via RLSSMConfig's own factory method. - # The differentiable Op is passed so Config.validate() is satisfied; - # loglik_kind="approx_differentiable" reflects that the Op has gradients. - config = rlssm_config._build_model_config(loglik_op) + # Inject the built Op and backend directly onto rlssm_config so that + # HSSMBase stores the RLSSMConfig as-is — no Config conversion needed. + rlssm_config.loglik = loglik_op + rlssm_config.backend = "jax" super().__init__( data=data, - model_config=config, + model_config=rlssm_config, include=include, p_outlier=p_outlier, lapse=lapse, @@ -218,10 +218,13 @@ def _make_model_distribution(self) -> type[pm.Distribution]: RLSSM and ``missing_data`` / ``deadline`` are rejected in ``__init__`` before this method is ever reached. """ - list_params = self.model_config.list_params - assert list_params is not None, "model_config.list_params must be set" + # Use self.list_params (managed by HSSMBase, includes p_outlier when + # has_lapse=True) rather than self.model_config.list_params (the original + # config list, never mutated by HSSMBase). + list_params = self.list_params + assert list_params is not None, "list_params must be set" assert isinstance(list_params, list), ( - "model_config.list_params must be a list" + "list_params must be a list" ) # for type checker # p_outlier is a scalar mixture weight (not trialwise); every other @@ -235,15 +238,13 @@ def _make_model_distribution(self) -> type[pm.Distribution]: else [self.data[field].to_numpy(copy=True) for field in extra_fields] ) - # The differentiable pytensor Op was stored on the validated model_config - # during __init__ as its `loglik`; ensure it's present and cast for typing. + # The differentiable pytensor Op was stored on model_config.loglik during + # __init__; ensure it's present and cast for typing. assert self.model_config.loglik is not None, "model_config.loglik must be set" loglik_op = cast("Callable[..., Any] | Op", self.model_config.loglik) - # `model_config` is typed as BaseModelConfig on the base class; cast - # to `Config` here so static checkers understand `rv` exists. - cfg = cast("Config", self.model_config) - rv_name = cfg.rv or cfg.model_name + # RLSSMConfig carries no `rv` field; use model_name as the rv identifier. + rv_name = self.model_config.model_name return make_distribution( rv=rv_name, From 5658834f8ae58780e6bd73455219d5813f7a99e9 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 13:09:11 -0400 Subject: [PATCH 04/22] Update TestRLSSMConfigDefaults to reflect None for default parameters instead of fixed values --- tests/test_rlssm_config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 6f0bcbb3..95a8f447 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -249,12 +249,15 @@ class TestRLSSMConfigDefaults: @pytest.mark.parametrize( "list_params, params_default, bounds, param, expected_default, expected_bounds", [ + # params_default stores initialisation values, NOT priors. + # get_defaults always returns None for the prior so that + # prior_settings="safe" can assign priors from bounds. ( ["alpha", "beta", "gamma"], [0.5, 0.3, 0.2], {"beta": (0.0, 1.0)}, "beta", - 0.3, + None, (0.0, 1.0), ), (["alpha", "beta"], [0.5, 0.3], {"alpha": (0.0, 1.0)}, "gamma", None, None), From 7a294af489ddc25b4775106c0e9300884073b8ad Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 13:48:12 -0400 Subject: [PATCH 05/22] Refactor RLSSM to inject loglik and backend directly into a new RLSSMConfig instance, preserving the original configuration. --- src/hssm/rl/rlssm.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 9d3c6a75..401bab07 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -16,6 +16,7 @@ standard ``loglik`` / ``loglik_kind`` wrapping pipeline. """ +from dataclasses import replace from typing import TYPE_CHECKING, Any, Callable, Literal, cast import bambi as bmb @@ -178,14 +179,14 @@ def __init__( extra_fields=list(rlssm_config.extra_fields or []), ) - # Inject the built Op and backend directly onto rlssm_config so that - # HSSMBase stores the RLSSMConfig as-is — no Config conversion needed. - rlssm_config.loglik = loglik_op - rlssm_config.backend = "jax" + # Build a new RLSSMConfig with the Op and backend injected, leaving + # the caller's object unmodified (dataclasses.replace creates a shallow + # copy with only the specified fields overridden). + model_config = replace(rlssm_config, loglik=loglik_op, backend="jax") super().__init__( data=data, - model_config=rlssm_config, + model_config=model_config, include=include, p_outlier=p_outlier, lapse=lapse, From fd99efbebe03e63694955cb69f59746b69f7435f Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 13:52:30 -0400 Subject: [PATCH 06/22] Add validation for missing bounds in RLSSMConfig parameters --- src/hssm/config.py | 12 ++++++++++++ tests/test_rlssm_config.py | 8 ++++++++ 2 files changed, 20 insertions(+) diff --git a/src/hssm/config.py b/src/hssm/config.py index a2893660..46ba8811 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -461,6 +461,18 @@ def validate(self) -> None: f"match list_params length ({len(self.list_params)})" ) + # Every parameter must have bounds — get_defaults() returns (None, bounds) + # so missing bounds produce a cryptic "Bounds parameter unspecified" error + # deep in prior construction. Surface it here with a clear message. + if self.list_params: + missing_bounds = [p for p in self.list_params if p not in self.bounds] + if missing_bounds: + raise ValueError( + f"Missing bounds for parameter(s): {missing_bounds}. " + "Every parameter in `list_params` must have a corresponding " + "entry in `bounds`." + ) + def get_defaults( self, param: str ) -> tuple[float | None, tuple[float, float] | None]: diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 95a8f447..98729809 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -68,6 +68,7 @@ def _dummy_ssm_logp_func(x): model_name="test_model", list_params=["alpha", "beta"], params_default=[0.5, 0.3], + bounds={"alpha": (0.0, 1.0), "beta": (0.0, 1.0)}, decision_process="ddm", response=["rt", "response"], choices=[0, 1], @@ -244,6 +245,13 @@ def test_validate_ssm_logp_func_missing_annotations(self, valid_rlssmconfig_kwar ): config.validate() + def test_validate_missing_bounds_for_param(self, valid_rlssmconfig_kwargs): + """validate() should raise early when a param has no bounds entry.""" + kwargs = {**valid_rlssmconfig_kwargs, "bounds": {}} # strip all bounds + config = RLSSMConfig(**kwargs) + with pytest.raises(ValueError, match="Missing bounds for parameter"): + config.validate() + class TestRLSSMConfigDefaults: @pytest.mark.parametrize( From bc0f7ca61b78033b6ed51d0177678bf5968b5460 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 14:05:57 -0400 Subject: [PATCH 07/22] Fix RLSSM to use model_config for ssm_logp_func and update test cases for default parameter bounds --- src/hssm/rl/rlssm.py | 2 +- tests/test_rlssm_config.py | 25 ++++++++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 401bab07..429395fb 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -208,7 +208,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: through :func:`~hssm.distribution_utils.make_likelihood_callable`. Instead it uses ``self.loglik`` directly — the differentiable pytensor ``Op`` built in :meth:`__init__` from - ``self.config.ssm_logp_func``. + ``self.model_config.ssm_logp_func``. The Op already handles: - The RL learning rule (computing trial-wise intermediate parameters). diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 98729809..dbd7171d 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -260,16 +260,35 @@ class TestRLSSMConfigDefaults: # params_default stores initialisation values, NOT priors. # get_defaults always returns None for the prior so that # prior_settings="safe" can assign priors from bounds. + # + # Case 1: queried param is present in bounds → bound returned. ( ["alpha", "beta", "gamma"], [0.5, 0.3, 0.2], - {"beta": (0.0, 1.0)}, + {"alpha": (0.0, 1.0), "beta": (0.0, 1.0), "gamma": (0.0, 1.0)}, "beta", None, (0.0, 1.0), ), - (["alpha", "beta"], [0.5, 0.3], {"alpha": (0.0, 1.0)}, "gamma", None, None), - (["alpha", "beta"], [], {"alpha": (0.0, 1.0)}, "alpha", None, (0.0, 1.0)), + # Case 2: queried param is NOT in list_params and NOT in bounds + # (e.g. a typo or an extra lookup) → both None. + ( + ["alpha", "beta"], + [0.5, 0.3], + {"alpha": (0.0, 1.0), "beta": (0.0, 1.0)}, + "gamma", + None, + None, + ), + # Case 3: params_default may be empty; param in bounds → bound returned. + ( + ["alpha", "beta"], + [], + {"alpha": (0.0, 1.0), "beta": (0.0, 1.0)}, + "alpha", + None, + (0.0, 1.0), + ), ], ) def test_get_defaults_cases( From b075e4fb04f94c4b3083276be1d0a6eab8e5277d Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 14:50:41 -0400 Subject: [PATCH 08/22] Enhance RLSSM tests to align params_is_trialwise with list_params and add pickle round-trip verification --- tests/test_rlssm.py | 45 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index 0e60bfb2..e1ac73b9 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -218,14 +218,18 @@ def test_rlssm_deadline_raises( def test_rlssm_params_is_trialwise_aligned( rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig ) -> None: - """params_is_trialwise must align with list_params (same length, p_outlier=False).""" + """params_is_trialwise must align with self.list_params (HSSMBase-managed copy). + + self.list_params includes p_outlier when lapse is active; model_config.list_params + does not. The test must use self.list_params to cover the lapse parameter path. + """ model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) - assert model.model_config.list_params is not None - params_is_trialwise = [ - name != "p_outlier" for name in model.model_config.list_params - ] - assert len(params_is_trialwise) == len(model.model_config.list_params) - for name, is_tw in zip(model.model_config.list_params, params_is_trialwise): + assert model.list_params is not None + params_is_trialwise = [name != "p_outlier" for name in model.list_params] + assert len(params_is_trialwise) == len(model.list_params) + # p_outlier (added by HSSMBase when lapse is active) must be non-trialwise. + assert "p_outlier" in model.list_params, "fixture uses default p_outlier=0.05" + for name, is_tw in zip(model.list_params, params_is_trialwise): if name == "p_outlier": assert not is_tw, "p_outlier must be non-trialwise" else: @@ -313,3 +317,30 @@ def test_rlssm_sample_smoke(rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig) draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9 ) assert trace is not None + + +def test_rlssm_pickle_round_trip( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """cloudpickle round-trip must reconstruct an equivalent RLSSM. + + Verifies that __getstate__ / __setstate__ survive serialisation: + - The reconstructed object is a fresh RLSSM (not the same instance). + - n_participants and n_trials are preserved. + - list_params (including p_outlier) are preserved. + - model_config.model_name is preserved. + - model.model (bambi model) is rebuilt, confirming full re-initialisation. + """ + import cloudpickle + + model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + blob = cloudpickle.dumps(model) + restored = cloudpickle.loads(blob) + + assert restored is not model + assert isinstance(restored, RLSSM) + assert restored.n_participants == model.n_participants + assert restored.n_trials == model.n_trials + assert restored.list_params == model.list_params + assert restored.model_config.model_name == model.model_config.model_name + assert restored.model is not None From 27d505e3f98673f2cc940ec05b59c370b1be319d Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 14:51:10 -0400 Subject: [PATCH 09/22] Add test to ensure RLSSMConfig.from_defaults raises NotImplementedError --- tests/test_rlssm_config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index dbd7171d..232d13a6 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -252,6 +252,11 @@ def test_validate_missing_bounds_for_param(self, valid_rlssmconfig_kwargs): with pytest.raises(ValueError, match="Missing bounds for parameter"): config.validate() + def test_from_defaults_raises(self): + """RLSSMConfig.from_defaults() must raise NotImplementedError.""" + with pytest.raises(NotImplementedError, match="from_defaults"): + RLSSMConfig.from_defaults("ddm", None) + class TestRLSSMConfigDefaults: @pytest.mark.parametrize( From ce8e187017638cdfcf8000939b375687de4eb3b4 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 14:52:07 -0400 Subject: [PATCH 10/22] Clarify RLSSMConfig.from_defaults behavior and raise NotImplementedError for unsupported usage --- src/hssm/config.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index 46ba8811..b8e6992d 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -352,13 +352,32 @@ def __post_init__(self): """Set default loglik_kind for RLSSM models if not provided.""" if self.loglik_kind is None: self.loglik_kind = "approx_differentiable" + _logger.debug( + "RLSSMConfig: loglik_kind not specified; " + "defaulting to 'approx_differentiable'." + ) @classmethod def from_defaults( cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None - ) -> Config: - """Return the shared Config defaults (delegated to :class:`Config`).""" - return Config.from_defaults(model_name, loglik_kind) + ) -> "RLSSMConfig": + """Not supported for RLSSMConfig. + + RLSSM models are always constructed via + :meth:`RLSSMConfig.from_rlssm_dict` or directly via the constructor. + This override exists only to prevent accidental delegation to + :meth:`Config.from_defaults`, which returns a :class:`Config` and + would mislead callers expecting an :class:`RLSSMConfig`. + + Raises + ------ + NotImplementedError + Always. + """ + raise NotImplementedError( + "RLSSMConfig does not support from_defaults(). " + "Use RLSSMConfig.from_rlssm_dict() or the constructor directly." + ) @classmethod def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> RLSSMConfig: From 7c7fd32a787b1d2c72026ec1be3e045449894d00 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 18 Mar 2026 15:13:34 -0400 Subject: [PATCH 11/22] Inject JAX backend into RLSSMConfig during initialization --- src/hssm/rl/rlssm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 429395fb..61d52431 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -182,6 +182,10 @@ def __init__( # Build a new RLSSMConfig with the Op and backend injected, leaving # the caller's object unmodified (dataclasses.replace creates a shallow # copy with only the specified fields overridden). + # + # backend is hardcoded to "jax" because the entire RLSSM likelihood + # stack is JAX-only. See ssm_logp_func, make_rl_logp_op, and + # _make_model_distribution for details. model_config = replace(rlssm_config, loglik=loglik_op, backend="jax") super().__init__( From a3898d7d67732378f7abf8a00cc5b78239e70ed9 Mon Sep 17 00:00:00 2001 From: cpaniaguam Date: Thu, 19 Mar 2026 11:51:47 -0400 Subject: [PATCH 12/22] Fix merge conflicts with base branch --- src/hssm/rl/rlssm.py | 10 +-- tests/test_rlssm.py | 2 +- tests/test_rlssm_config.py | 139 ------------------------------------- 3 files changed, 6 insertions(+), 145 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index f231360b..2922dc49 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -187,11 +187,11 @@ def __init__( # backend is hardcoded to "jax" because the entire RLSSM likelihood # stack is JAX-only. See ssm_logp_func, make_rl_logp_op, and # _make_model_distribution for details. - model_config = replace(rlssm_config, loglik=loglik_op, backend="jax") - # Build a typed Config instance via RLSSMConfig's own factory method. - # The differentiable Op is passed so Config.validate() is satisfied; - # loglik_kind="approx_differentiable" reflects that the Op has gradients. - config = model_config._build_model_config(loglik_op) + model_config = replace(model_config, loglik=loglik_op, backend="jax") + # # Build a typed Config instance via RLSSMConfig's own factory method. + # # The differentiable Op is passed so Config.validate() is satisfied; + # # loglik_kind="approx_differentiable" reflects that the Op has gradients. + # config = model_config._build_model_config(loglik_op) super().__init__( data=data, diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index c1927461..b6124b18 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -313,7 +313,7 @@ def test_rlssm_pickle_round_trip( """ import cloudpickle - model = RLSSM(data=rldm_data, rlssm_config=rlssm_config) + model = RLSSM(data=rldm_data, model_config=rlssm_config) blob = cloudpickle.dumps(model) restored = cloudpickle.loads(blob) diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index d705a3d8..232d13a6 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -322,145 +322,6 @@ def test_get_defaults_cases( assert bounds_val == expected_bounds -class TestRLSSMConfigConversion: - @pytest.mark.parametrize( - "list_params, params_default, backend, expected_backend, expected_default_priors, raises", - [ - ( - ["alpha", "beta", "v", "a"], - [0.5, 0.3, 1.0, 1.5], - "jax", - "jax", - {"alpha": 0.5, "beta": 0.3, "v": 1.0, "a": 1.5}, - None, - ), - (["alpha"], [0.5], None, "jax", {"alpha": 0.5}, None), - (["alpha", "beta"], [], None, "jax", {}, None), - (["alpha", "beta", "gamma"], [0.5, 0.3], None, None, None, ValueError), - ], - ) - def test_to_config_cases( - self, - list_params, - params_default, - backend, - expected_backend, - expected_default_priors, - raises, - ): - model_config = RLSSMConfig( - model_name="test_model", - list_params=list_params, - params_default=params_default, - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - backend=backend, - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - if raises: - with pytest.raises(raises): - model_config.to_config() - else: - config = model_config.to_config() - assert config.backend == expected_backend - assert config.default_priors == expected_default_priors - - def test_to_config(self): - model_config = RLSSMConfig( - model_name="rlwm", - description="RLWM model", - list_params=["alpha", "beta", "v", "a"], - params_default=[0.5, 0.3, 1.0, 1.5], - bounds={ - "alpha": (0.0, 1.0), - "beta": (0.0, 1.0), - "v": (-3.0, 3.0), - "a": (0.3, 2.5), - }, - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - extra_fields=["feedback"], - backend="jax", - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = model_config.to_config() - assert isinstance(config, Config) - assert config.model_name == "rlwm" - assert config.description == "RLWM model" - assert config.list_params == ["alpha", "beta", "v", "a"] - assert config.response == ["rt", "response"] - assert config.choices == [0, 1] - assert config.extra_fields == ["feedback"] - assert config.backend == "jax" - assert config.loglik_kind == "approx_differentiable" - assert config.bounds == { - "alpha": (0.0, 1.0), - "beta": (0.0, 1.0), - "v": (-3.0, 3.0), - "a": (0.3, 2.5), - } - assert config.default_priors == { - "alpha": 0.5, - "beta": 0.3, - "v": 1.0, - "a": 1.5, - } - - def test_to_config_defaults_backend(self): - model_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha"], - params_default=[0.5], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = model_config.to_config() - assert config.backend == "jax" - - def test_to_config_no_defaults(self): - model_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta"], - params_default=[], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = model_config.to_config() - assert config.default_priors == {} - - def test_to_config_mismatched_defaults_length(self): - model_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta", "gamma"], - params_default=[0.5, 0.3], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - with pytest.raises( - ValueError, - match=r"params_default length \(2\) doesn't match list_params length \(3\)", - ): - model_config.to_config() - - class TestRLSSMConfigLearningProcess: def test_learning_process(self): config = RLSSMConfig( From 4d99410bf894b4521e0ad4f71246e435e2850572 Mon Sep 17 00:00:00 2001 From: cpaniaguam Date: Thu, 19 Mar 2026 12:39:15 -0400 Subject: [PATCH 13/22] Remove commented out lines --- src/hssm/rl/rlssm.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 2922dc49..d8d02287 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -188,10 +188,6 @@ def __init__( # stack is JAX-only. See ssm_logp_func, make_rl_logp_op, and # _make_model_distribution for details. model_config = replace(model_config, loglik=loglik_op, backend="jax") - # # Build a typed Config instance via RLSSMConfig's own factory method. - # # The differentiable Op is passed so Config.validate() is satisfied; - # # loglik_kind="approx_differentiable" reflects that the Op has gradients. - # config = model_config._build_model_config(loglik_op) super().__init__( data=data, From f04f47ebd4fc2678ed4eba2f3cedadbdb8a45318 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Mar 2026 14:36:24 -0400 Subject: [PATCH 14/22] Remove RLSSMConfig import from __init__.py --- src/hssm/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/hssm/__init__.py b/src/hssm/__init__.py index 72aa14ba..2f234d08 100644 --- a/src/hssm/__init__.py +++ b/src/hssm/__init__.py @@ -11,7 +11,7 @@ import logging import sys -from .config import ModelConfig, RLSSMConfig +from .config import ModelConfig from .datasets import load_data from .defaults import show_defaults from .hssm import HSSM @@ -33,7 +33,6 @@ __all__ = [ "HSSM", "RLSSM", - "RLSSMConfig", "Link", "load_data", "ModelConfig", From 11115af36e87b5bcf2e1b41b24e6f4adffab3478 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Mar 2026 14:36:39 -0400 Subject: [PATCH 15/22] Reorganize import statements by moving RLSSMConfig import to the correct position --- src/hssm/rl/rlssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index d8d02287..39fbfad2 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -27,7 +27,6 @@ from pytensor.graph import Op -from hssm.config import RLSSMConfig from hssm.defaults import ( INITVAL_JITTER_SETTINGS, ) @@ -36,6 +35,7 @@ from hssm.rl.utils import validate_balanced_panel from ..base import HSSMBase +from .config import RLSSMConfig class RLSSM(HSSMBase): From 6a9384f03df69fcfcd200c28e46861103a39e0e5 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Mar 2026 14:39:56 -0400 Subject: [PATCH 16/22] Move RLSSMConfig import to the correct module in test files --- tests/test_rlssm.py | 2 +- tests/test_rlssm_config.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index b6124b18..cf340480 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -15,7 +15,7 @@ import pytest import hssm -from hssm import RLSSM, RLSSMConfig +from hssm.rl import RLSSM, RLSSMConfig from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise from hssm.utils import annotate_function diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 232d13a6..404379fb 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -1,7 +1,8 @@ import pytest import hssm -from hssm.config import Config, ModelConfig, RLSSMConfig +from hssm.config import Config, ModelConfig +from hssm.rl import RLSSMConfig from hssm.utils import annotate_function # Define constants for repeated data structures From 0285f04f51bbbdd1225638ca69e4675b8d288ad3 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Mar 2026 14:40:26 -0400 Subject: [PATCH 17/22] Update docstring in __init__.py and exports --- src/hssm/rl/__init__.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/hssm/rl/__init__.py b/src/hssm/rl/__init__.py index 09eb646a..64e17bc4 100644 --- a/src/hssm/rl/__init__.py +++ b/src/hssm/rl/__init__.py @@ -1,18 +1,27 @@ -"""Reinforcement learning extensions for HSSM. +"""Reinforcement-learning extensions for HSSM. -This sub-package provides: +This subpackage groups components that integrate reinforcement-learning +learning rules with sequential-sampling decision models (SSMs). + +Public API (import from ``hssm.rl``): + +- ``RLSSM``: the RL + SSM model class implemented in :mod:`hssm.rl.rlssm`. +- ``RLSSMConfig``: the config class for RL + SSM models, implemented in + :mod:`hssm.rl.config`. +- ``validate_balanced_panel``: panel-balance utility in :mod:`hssm.rl.utils`. + +RL likelihood builders live in :mod:`hssm.rl.likelihoods.builder` and include +helpers such as :func:`~hssm.rl.likelihoods.builder.make_rl_logp_func` and +:func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`. -- :class:`~hssm.rl.rlssm.RLSSM` — the RL + SSM model class. -- :func:`~hssm.rl.utils.validate_balanced_panel` — panel-balance utility. -- :mod:`hssm.rl.likelihoods` — log-likelihood builders - (:func:`~hssm.rl.likelihoods.builder.make_rl_logp_func`, - :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`). """ +from .config import RLSSMConfig from .rlssm import RLSSM from .utils import validate_balanced_panel __all__ = [ "RLSSM", + "RLSSMConfig", "validate_balanced_panel", ] From 5807a71f0a71dffbc323ae42945e95cac5e9d736 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Mar 2026 14:40:34 -0400 Subject: [PATCH 18/22] Remove RLSSMConfig class and its associated methods from config.py --- src/hssm/config.py | 189 --------------------------------------------- 1 file changed, 189 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index b8e6992d..66348ab7 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -324,195 +324,6 @@ def _build_model_config( return config -@dataclass -class RLSSMConfig(BaseModelConfig): - """Config for reinforcement learning + sequential sampling models. - - This configuration class is designed for models that combine reinforcement - learning processes with sequential sampling decision models (RLSSM). - - The ``ssm_logp_func`` field holds the fully annotated JAX SSM log-likelihood - function (an :class:`AnnotatedFunction`) that is passed directly to - ``make_rl_logp_op``. It supersedes the ``loglik`` / ``loglik_kind`` workflow - used by :class:`HSSM`: the Op is built from ``ssm_logp_func`` and therefore - no ``loglik`` callable needs to be provided. - """ - - decision_process_loglik_kind: str = field(kw_only=True) - learning_process_loglik_kind: str = field(kw_only=True) - params_default: list[float] = field(kw_only=True) - decision_process: str | ModelConfig = field(kw_only=True) - learning_process: dict[str, Any] = field(kw_only=True) - # The fully annotated SSM log-likelihood function used by make_rl_logp_op. - # Type is Any to avoid a hard dependency on the AnnotatedFunction Protocol at - # import time; validated at runtime in validate(). - ssm_logp_func: Any = field(default=None, kw_only=True) - - def __post_init__(self): - """Set default loglik_kind for RLSSM models if not provided.""" - if self.loglik_kind is None: - self.loglik_kind = "approx_differentiable" - _logger.debug( - "RLSSMConfig: loglik_kind not specified; " - "defaulting to 'approx_differentiable'." - ) - - @classmethod - def from_defaults( - cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None - ) -> "RLSSMConfig": - """Not supported for RLSSMConfig. - - RLSSM models are always constructed via - :meth:`RLSSMConfig.from_rlssm_dict` or directly via the constructor. - This override exists only to prevent accidental delegation to - :meth:`Config.from_defaults`, which returns a :class:`Config` and - would mislead callers expecting an :class:`RLSSMConfig`. - - Raises - ------ - NotImplementedError - Always. - """ - raise NotImplementedError( - "RLSSMConfig does not support from_defaults(). " - "Use RLSSMConfig.from_rlssm_dict() or the constructor directly." - ) - - @classmethod - def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> RLSSMConfig: - """ - Create RLSSMConfig from a configuration dictionary. - - Parameters - ---------- - config_dict : dict[str, Any] - Dictionary containing model configuration. Expected keys: - - model_name: Model identifier (required) - - description: Model description (required) - - list_params: List of parameter names (required) - - extra_fields: List of extra field names from data (required) - - params_default: Default parameter values (required) - - bounds: Parameter bounds (required) - - response: Response column names (required) - - choices: Valid choice values (required) - - decision_process: Decision process specification (required) - - learning_process: Learning process functions (required) - - decision_process_loglik_kind: Likelihood kind for decision process - (required) - - learning_process_loglik_kind: Likelihood kind for learning process - (required) - - Returns - ------- - RLSSMConfig - Configured RLSSM model configuration object. - """ - # Check for required fields and raise explicit errors if missing - for field_name in RLSSM_REQUIRED_FIELDS: - if field_name not in config_dict or config_dict[field_name] is None: - raise ValueError(f"{field_name} must be provided in config_dict") - - return cls( - model_name=config_dict["model_name"], - description=config_dict["description"], - list_params=config_dict["list_params"], - extra_fields=config_dict.get("extra_fields"), - params_default=config_dict["params_default"], - decision_process=config_dict["decision_process"], - learning_process=config_dict["learning_process"], - ssm_logp_func=config_dict["ssm_logp_func"], - bounds=config_dict.get("bounds", {}), - response=config_dict["response"], - choices=config_dict["choices"], - decision_process_loglik_kind=config_dict["decision_process_loglik_kind"], - learning_process_loglik_kind=config_dict["learning_process_loglik_kind"], - ) - - def validate(self) -> None: - """Validate RLSSM configuration. - - Raises - ------ - ValueError - If required fields are missing or inconsistent. - """ - if self.response is None: - raise ValueError("Please provide `response` columns in the configuration.") - if self.list_params is None: - raise ValueError("Please provide `list_params` in the configuration.") - if self.choices is None: - raise ValueError("Please provide `choices` in the configuration.") - if self.decision_process is None: - raise ValueError("Please specify a `decision_process`.") - if self.ssm_logp_func is None: - raise ValueError( - "Please provide `ssm_logp_func`: the fully annotated JAX SSM " - "log-likelihood function required by `make_rl_logp_op`." - ) - if not callable(self.ssm_logp_func): - raise ValueError( - "`ssm_logp_func` must be a callable, " - f"but got {type(self.ssm_logp_func)!r}." - ) - missing_attrs = [ - attr - for attr in ("inputs", "outputs", "computed") - if not hasattr(self.ssm_logp_func, attr) - ] - if missing_attrs: - raise ValueError( - "`ssm_logp_func` must be decorated with `@annotate_function` " - "so that it carries the attributes required by `make_rl_logp_op`. " - f"Missing attribute(s): {missing_attrs}. " - "Decorate the function like:\n\n" - " @annotate_function(\n" - " inputs=[...], outputs=[...], computed={...}\n" - " )\n" - " def my_ssm_logp(lan_matrix): ..." - ) - - # Validate parameter defaults consistency - if self.params_default and self.list_params: - if len(self.params_default) != len(self.list_params): - raise ValueError( - f"params_default length ({len(self.params_default)}) doesn't " - f"match list_params length ({len(self.list_params)})" - ) - - # Every parameter must have bounds — get_defaults() returns (None, bounds) - # so missing bounds produce a cryptic "Bounds parameter unspecified" error - # deep in prior construction. Surface it here with a clear message. - if self.list_params: - missing_bounds = [p for p in self.list_params if p not in self.bounds] - if missing_bounds: - raise ValueError( - f"Missing bounds for parameter(s): {missing_bounds}. " - "Every parameter in `list_params` must have a corresponding " - "entry in `bounds`." - ) - - def get_defaults( - self, param: str - ) -> tuple[float | None, tuple[float, float] | None]: - """Return default value and bounds for a parameter. - - Parameters - ---------- - param - The name of the parameter. - - Returns - ------- - tuple - ``(None, bounds)`` — the default prior is intentionally ``None`` so that - ``prior_settings="safe"`` in :class:`~hssm.base.HSSMBase` assigns - priors from bounds rather than fixing parameters to scalar constants. - ``params_default`` stores initialisation values, not prior distributions. - """ - return None, self.bounds.get(param) - - @dataclass class ModelConfig: """Representation for model_config provided by the user.""" From 4bf67ea666c8ef90aa09b0b6787b7c7bbab70ced Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Mar 2026 14:41:02 -0400 Subject: [PATCH 19/22] Move RLSSMConfig class hssm.rl module --- src/hssm/rl/config.py | 148 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 src/hssm/rl/config.py diff --git a/src/hssm/rl/config.py b/src/hssm/rl/config.py new file mode 100644 index 00000000..8ca9bf8c --- /dev/null +++ b/src/hssm/rl/config.py @@ -0,0 +1,148 @@ +"""RL-specific configuration classes. + +This module houses `RLSSMConfig` which was previously defined in +`hssm.config`. It is intentionally lightweight and re-uses +`BaseModelConfig` from :mod:`hssm.config` to avoid duplicating core +behaviour. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .._types import LoglikKind, SupportedModels + from ..config import ModelConfig + +from ..config import BaseModelConfig + +_logger = logging.getLogger("hssm") + +# Local copy of required fields for RLSSM configs. Kept here so the class +# can be imported without importing the entirety of `hssm.config`'s runtime +# machinery earlier than necessary. +RLSSM_REQUIRED_FIELDS = ( + "model_name", + "description", + "list_params", + "bounds", + "params_default", + "choices", + "decision_process", + "learning_process", + "response", + "decision_process_loglik_kind", + "learning_process_loglik_kind", + "extra_fields", + "ssm_logp_func", +) + + +@dataclass +class RLSSMConfig(BaseModelConfig): + """Config for reinforcement learning + sequential sampling models. + + The ``ssm_logp_func`` field holds the fully annotated JAX SSM + log-likelihood function (an :class:`AnnotatedFunction`) that is passed + directly to ``make_rl_logp_op``. + """ + + decision_process_loglik_kind: str = field(kw_only=True) + learning_process_loglik_kind: str = field(kw_only=True) + params_default: list[float] = field(kw_only=True) + decision_process: str | "ModelConfig" = field(kw_only=True) + learning_process: dict[str, Any] = field(kw_only=True) + ssm_logp_func: Any = field(default=None, kw_only=True) + + def __post_init__(self): # noqa: D105 + if self.loglik_kind is None: + self.loglik_kind = "approx_differentiable" + _logger.debug( + "RLSSMConfig: loglik_kind not specified; " + "defaulting to 'approx_differentiable'." + ) + + @classmethod + def from_defaults( # noqa: D102 + cls, model_name: "SupportedModels" | str, loglik_kind: "LoglikKind" | None + ) -> "RLSSMConfig": + raise NotImplementedError( + "RLSSMConfig does not support from_defaults(). " + "Use RLSSMConfig.from_rlssm_dict() or the constructor directly." + ) + + @classmethod + def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": # noqa: D102 + for field_name in RLSSM_REQUIRED_FIELDS: + if field_name not in config_dict or config_dict[field_name] is None: + raise ValueError(f"{field_name} must be provided in config_dict") + + return cls( + model_name=config_dict["model_name"], + description=config_dict["description"], + list_params=config_dict["list_params"], + extra_fields=config_dict.get("extra_fields"), + params_default=config_dict["params_default"], + decision_process=config_dict["decision_process"], + learning_process=config_dict["learning_process"], + ssm_logp_func=config_dict["ssm_logp_func"], + bounds=config_dict.get("bounds", {}), + response=config_dict["response"], + choices=config_dict["choices"], + decision_process_loglik_kind=config_dict["decision_process_loglik_kind"], + learning_process_loglik_kind=config_dict["learning_process_loglik_kind"], + ) + + def validate(self) -> None: # noqa: D102 + if self.response is None: + raise ValueError("Please provide `response` columns in the configuration.") + if self.list_params is None: + raise ValueError("Please provide `list_params` in the configuration.") + if self.choices is None: + raise ValueError("Please provide `choices` in the configuration.") + if self.decision_process is None: + raise ValueError("Please specify a `decision_process`.") + if self.ssm_logp_func is None: + raise ValueError( + "Please provide `ssm_logp_func`: the fully annotated JAX SSM " + "log-likelihood function required by `make_rl_logp_op`." + ) + if not callable(self.ssm_logp_func): + raise ValueError( + "`ssm_logp_func` must be a callable, " + f"but got {type(self.ssm_logp_func)!r}." + ) + missing_attrs = [ + attr + for attr in ("inputs", "outputs", "computed") + if not hasattr(self.ssm_logp_func, attr) + ] + if missing_attrs: + raise ValueError( + "`ssm_logp_func` must be decorated with `@annotate_function` " + "so that it carries the attributes required by `make_rl_logp_op`. " + f"Missing attribute(s): {missing_attrs}. " + ) + + if self.params_default and self.list_params: + if len(self.params_default) != len(self.list_params): + raise ValueError( + f"params_default length ({len(self.params_default)}) doesn't " + f"match list_params length ({len(self.list_params)})" + ) + + if self.list_params: + missing_bounds = [p for p in self.list_params if p not in self.bounds] + if missing_bounds: + raise ValueError( + f"Missing bounds for parameter(s): {missing_bounds}. " + "Every parameter in `list_params` must have a corresponding " + "entry in `bounds`." + ) + + def get_defaults( # noqa: D102 + self, param: str + ) -> tuple[float | None, tuple[float, float] | None]: + return None, self.bounds.get(param) From 5d74bfe0ae0c9989c4ca3caa002f60af78ff3ef3 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 27 Mar 2026 10:54:58 -0400 Subject: [PATCH 20/22] Refactor config.py to remove RLSSM-specific defaults and unify observed data constants --- src/hssm/config.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/src/hssm/config.py b/src/hssm/config.py index 66348ab7..4c71658c 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -27,27 +27,10 @@ _logger = logging.getLogger("hssm") -# ====== Centralized RLSSM defaults ===== +# ====== Centralized SSM defaults ===== DEFAULT_SSM_OBSERVED_DATA = ["rt", "response"] -DEFAULT_RLSSM_OBSERVED_DATA = ["rt", "response"] DEFAULT_SSM_CHOICES = (0, 1) -RLSSM_REQUIRED_FIELDS = ( - "model_name", - "description", - "list_params", - "bounds", - "params_default", - "choices", - "decision_process", - "learning_process", - "response", - "decision_process_loglik_kind", - "learning_process_loglik_kind", - "extra_fields", - "ssm_logp_func", -) - ParamSpec = Union[float, dict[str, Any], Prior, None] @@ -196,7 +179,7 @@ def from_defaults( return Config( model_name=model_name, loglik_kind=loglik_kind, - response=DEFAULT_RLSSM_OBSERVED_DATA, + response=DEFAULT_SSM_OBSERVED_DATA, ) def update_loglik(self, loglik: Any | None) -> None: From 91b1098387acd6bc1cfac1e076e8f94bd3a44c21 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 27 Mar 2026 11:59:20 -0400 Subject: [PATCH 21/22] Enhance validation in RLSSMConfig for ssm_logp_func attributes --- src/hssm/rl/config.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/hssm/rl/config.py b/src/hssm/rl/config.py index 8ca9bf8c..dc8e1d2c 100644 --- a/src/hssm/rl/config.py +++ b/src/hssm/rl/config.py @@ -104,20 +104,21 @@ def validate(self) -> None: # noqa: D102 raise ValueError("Please provide `choices` in the configuration.") if self.decision_process is None: raise ValueError("Please specify a `decision_process`.") - if self.ssm_logp_func is None: + + logpfunc = self.ssm_logp_func + if logpfunc is None: raise ValueError( "Please provide `ssm_logp_func`: the fully annotated JAX SSM " "log-likelihood function required by `make_rl_logp_op`." ) - if not callable(self.ssm_logp_func): + if not callable(logpfunc): raise ValueError( - "`ssm_logp_func` must be a callable, " - f"but got {type(self.ssm_logp_func)!r}." + f"`ssm_logp_func` must be a callable, but got {type(logpfunc)!r}." ) missing_attrs = [ attr for attr in ("inputs", "outputs", "computed") - if not hasattr(self.ssm_logp_func, attr) + if not hasattr(logpfunc, attr) ] if missing_attrs: raise ValueError( @@ -126,6 +127,13 @@ def validate(self) -> None: # noqa: D102 f"Missing attribute(s): {missing_attrs}. " ) + if not isinstance(logpfunc.computed, dict) or not all( + callable(v) for v in logpfunc.computed.values() + ): + raise ValueError( + "`ssm_logp_func.computed` must be a dictionary with callable values." + ) + if self.params_default and self.list_params: if len(self.params_default) != len(self.list_params): raise ValueError( From c3a4f527bb1aed125f642f99378fc44eed999b6e Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 27 Mar 2026 12:09:53 -0400 Subject: [PATCH 22/22] Add validation test for non-callable values in ssm_logp_func.computed --- tests/test_rlssm_config.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 404379fb..513f4274 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -246,6 +246,20 @@ def test_validate_ssm_logp_func_missing_annotations(self, valid_rlssmconfig_kwar ): config.validate() + def test_validate_ssm_logp_func_computed_not_callable( + self, valid_rlssmconfig_kwargs + ): + """`computed` exists but contains non-callable values -> error.""" + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + # Inject a computed mapping with a non-callable value to trigger the + # specific validation branch. + config.ssm_logp_func.computed = {"x": "not_callable"} + with pytest.raises( + ValueError, + match=r"`ssm_logp_func.computed` must be a dictionary with callable values\.", + ): + config.validate() + def test_validate_missing_bounds_for_param(self, valid_rlssmconfig_kwargs): """validate() should raise early when a param has no bounds entry.""" kwargs = {**valid_rlssmconfig_kwargs, "bounds": {}} # strip all bounds