diff --git a/src/hssm/__init__.py b/src/hssm/__init__.py index 72aa14ba..2f234d08 100644 --- a/src/hssm/__init__.py +++ b/src/hssm/__init__.py @@ -11,7 +11,7 @@ import logging import sys -from .config import ModelConfig, RLSSMConfig +from .config import ModelConfig from .datasets import load_data from .defaults import show_defaults from .hssm import HSSM @@ -33,7 +33,6 @@ __all__ = [ "HSSM", "RLSSM", - "RLSSMConfig", "Link", "load_data", "ModelConfig", diff --git a/src/hssm/base.py b/src/hssm/base.py index 95d5c97c..02c8c1da 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -273,7 +273,11 @@ def __init__( if self.model_config.response is not None else None ) - self.list_params = self.model_config.list_params + self.list_params = ( + list(self.model_config.list_params) + if self.model_config.list_params is not None + else None + ) self.choices = self.model_config.choices # type: ignore[assignment] self.model_name = self.model_config.model_name self.loglik = self.model_config.loglik diff --git a/src/hssm/config.py b/src/hssm/config.py index 4b75ba3f..4c71658c 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -27,27 +27,10 @@ _logger = logging.getLogger("hssm") -# ====== Centralized RLSSM defaults ===== +# ====== Centralized SSM defaults ===== DEFAULT_SSM_OBSERVED_DATA = ["rt", "response"] -DEFAULT_RLSSM_OBSERVED_DATA = ["rt", "response"] DEFAULT_SSM_CHOICES = (0, 1) -RLSSM_REQUIRED_FIELDS = ( - "model_name", - "description", - "list_params", - "bounds", - "params_default", - "choices", - "decision_process", - "learning_process", - "response", - "decision_process_loglik_kind", - "learning_process_loglik_kind", - "extra_fields", - "ssm_logp_func", -) - ParamSpec = Union[float, dict[str, Any], Prior, None] @@ -196,7 +179,7 @@ def from_defaults( return Config( model_name=model_name, loglik_kind=loglik_kind, - response=DEFAULT_RLSSM_OBSERVED_DATA, + response=DEFAULT_SSM_OBSERVED_DATA, ) def update_loglik(self, loglik: Any | None) -> None: @@ -324,274 +307,6 @@ def _build_model_config( return config -@dataclass -class RLSSMConfig(BaseModelConfig): - """Config for reinforcement learning + sequential sampling models. - - This configuration class is designed for models that combine reinforcement - learning processes with sequential sampling decision models (RLSSM). - - The ``ssm_logp_func`` field holds the fully annotated JAX SSM log-likelihood - function (an :class:`AnnotatedFunction`) that is passed directly to - ``make_rl_logp_op``. It supersedes the ``loglik`` / ``loglik_kind`` workflow - used by :class:`HSSM`: the Op is built from ``ssm_logp_func`` and therefore - no ``loglik`` callable needs to be provided. - """ - - decision_process_loglik_kind: str = field(kw_only=True) - learning_process_loglik_kind: str = field(kw_only=True) - params_default: list[float] = field(kw_only=True) - decision_process: str | ModelConfig = field(kw_only=True) - learning_process: dict[str, Any] = field(kw_only=True) - # The fully annotated SSM log-likelihood function used by make_rl_logp_op. - # Type is Any to avoid a hard dependency on the AnnotatedFunction Protocol at - # import time; validated at runtime in validate(). - ssm_logp_func: Any = field(default=None, kw_only=True) - - def __post_init__(self): - """Set default loglik_kind for RLSSM models if not provided.""" - if self.loglik_kind is None: - self.loglik_kind = "approx_differentiable" - - @classmethod - def from_defaults( - cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None - ) -> Config: - """Return the shared Config defaults (delegated to :class:`Config`).""" - return Config.from_defaults(model_name, loglik_kind) - - @classmethod - def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> RLSSMConfig: - """ - Create RLSSMConfig from a configuration dictionary. - - Parameters - ---------- - config_dict : dict[str, Any] - Dictionary containing model configuration. Expected keys: - - model_name: Model identifier (required) - - description: Model description (required) - - list_params: List of parameter names (required) - - extra_fields: List of extra field names from data (required) - - params_default: Default parameter values (required) - - bounds: Parameter bounds (required) - - response: Response column names (required) - - choices: Valid choice values (required) - - decision_process: Decision process specification (required) - - learning_process: Learning process functions (required) - - decision_process_loglik_kind: Likelihood kind for decision process - (required) - - learning_process_loglik_kind: Likelihood kind for learning process - (required) - - Returns - ------- - RLSSMConfig - Configured RLSSM model configuration object. - """ - # Check for required fields and raise explicit errors if missing - for field_name in RLSSM_REQUIRED_FIELDS: - if field_name not in config_dict or config_dict[field_name] is None: - raise ValueError(f"{field_name} must be provided in config_dict") - - return cls( - model_name=config_dict["model_name"], - description=config_dict["description"], - list_params=config_dict["list_params"], - extra_fields=config_dict.get("extra_fields"), - params_default=config_dict["params_default"], - decision_process=config_dict["decision_process"], - learning_process=config_dict["learning_process"], - ssm_logp_func=config_dict["ssm_logp_func"], - bounds=config_dict.get("bounds", {}), - response=config_dict["response"], - choices=config_dict["choices"], - decision_process_loglik_kind=config_dict["decision_process_loglik_kind"], - learning_process_loglik_kind=config_dict["learning_process_loglik_kind"], - ) - - def validate(self) -> None: - """Validate RLSSM configuration. - - Raises - ------ - ValueError - If required fields are missing or inconsistent. - """ - if self.response is None: - raise ValueError("Please provide `response` columns in the configuration.") - if self.list_params is None: - raise ValueError("Please provide `list_params` in the configuration.") - if self.choices is None: - raise ValueError("Please provide `choices` in the configuration.") - if self.decision_process is None: - raise ValueError("Please specify a `decision_process`.") - if self.ssm_logp_func is None: - raise ValueError( - "Please provide `ssm_logp_func`: the fully annotated JAX SSM " - "log-likelihood function required by `make_rl_logp_op`." - ) - if not callable(self.ssm_logp_func): - raise ValueError( - "`ssm_logp_func` must be a callable, " - f"but got {type(self.ssm_logp_func)!r}." - ) - missing_attrs = [ - attr - for attr in ("inputs", "outputs", "computed") - if not hasattr(self.ssm_logp_func, attr) - ] - if missing_attrs: - raise ValueError( - "`ssm_logp_func` must be decorated with `@annotate_function` " - "so that it carries the attributes required by `make_rl_logp_op`. " - f"Missing attribute(s): {missing_attrs}. " - "Decorate the function like:\n\n" - " @annotate_function(\n" - " inputs=[...], outputs=[...], computed={...}\n" - " )\n" - " def my_ssm_logp(lan_matrix): ..." - ) - - # Validate parameter defaults consistency - if self.params_default and self.list_params: - if len(self.params_default) != len(self.list_params): - raise ValueError( - f"params_default length ({len(self.params_default)}) doesn't " - f"match list_params length ({len(self.list_params)})" - ) - - def get_defaults( - self, param: str - ) -> tuple[float | None, tuple[float, float] | None]: - """Return default value and bounds for a parameter. - - Parameters - ---------- - param - The name of the parameter. - - Returns - ------- - tuple - A tuple of (default_value, bounds) where: - - default_value is a float or None if not found - - bounds is a tuple (lower, upper) or None if not found - """ - # Try to find the parameter in list_params and get its default value - default_val = None - if self.list_params is not None: - try: - param_idx = self.list_params.index(param) - if self.params_default and param_idx < len(self.params_default): - default_val = self.params_default[param_idx] - except ValueError: - # Parameter not in list_params - pass - - return default_val, self.bounds.get(param) - - def to_config(self) -> Config: - """Convert to standard Config for compatibility with HSSM. - - This method transforms the RLSSM configuration into a standard Config - object that can be used with the existing HSSM infrastructure. - - Returns - ------- - Config - A Config object with RLSSM parameters mapped to standard format. - - Notes - ----- - The transformation converts params_default list to default_priors dict, - mapping parameter names to their default values. - """ - # Validate parameter defaults consistency before conversion - if self.params_default and self.list_params: - if len(self.params_default) != len(self.list_params): - raise ValueError( - f"params_default length ({len(self.params_default)}) doesn't " - f"match list_params length ({len(self.list_params)}). " - "This would result in silent data loss during conversion." - ) - - # Transform params_default list to default_priors dict - default_priors = ( - { - param: default - for param, default in zip(self.list_params, self.params_default) - } - if self.list_params and self.params_default - else {} - ) - - return Config( - model_name=self.model_name, - loglik_kind=self.loglik_kind, - response=self.response, - choices=self.choices, - list_params=self.list_params, - description=self.description, - bounds=self.bounds, - default_priors=cast( - "dict[str, float | dict[str, Any] | Any | None]", default_priors - ), - extra_fields=self.extra_fields, - backend=self.backend or "jax", # RLSSM typically uses JAX - loglik=self.loglik, - ) - - def to_model_config(self) -> ModelConfig: - """Build a :class:`ModelConfig` from this :class:`RLSSMConfig`. - - All fields are sourced from ``self``; the backend is fixed to ``"jax"`` - because RLSSM exclusively uses the JAX backend. - - ``default_priors`` is intentionally left empty so the - ``prior_settings="safe"`` mechanism in :class:`~hssm.base.HSSMBase` - assigns sensible priors from bounds rather than fixing every parameter - to a constant scalar. - """ - return ModelConfig( - response=tuple(self.response), # type: ignore[arg-type] - list_params=list(self.list_params), # type: ignore[arg-type] - choices=tuple(self.choices), # type: ignore[arg-type] - default_priors={}, - bounds=self.bounds, - extra_fields=self.extra_fields, - backend="jax", - ) - - def _build_model_config(self, loglik_op: Any) -> Config: - """Build a validated :class:`Config` for use by :class:`~hssm.rl.rlssm.RLSSM`. - - Converts this :class:`RLSSMConfig` to a :class:`ModelConfig`, then - delegates to :meth:`Config._build_model_config` using the pre-built - differentiable Op as ``loglik``. - - Parameters - ---------- - loglik_op - The differentiable pytensor Op produced by - :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`. - - Returns - ------- - Config - A fully validated :class:`Config` ready to pass to - :meth:`~hssm.base.HSSMBase.__init__`. - """ - mc = self.to_model_config() - return Config._build_model_config( - self.model_name, - "approx_differentiable", - mc, - None, - loglik_op, - ) - - @dataclass class ModelConfig: """Representation for model_config provided by the user.""" diff --git a/src/hssm/rl/__init__.py b/src/hssm/rl/__init__.py index 09eb646a..64e17bc4 100644 --- a/src/hssm/rl/__init__.py +++ b/src/hssm/rl/__init__.py @@ -1,18 +1,27 @@ -"""Reinforcement learning extensions for HSSM. +"""Reinforcement-learning extensions for HSSM. -This sub-package provides: +This subpackage groups components that integrate reinforcement-learning +learning rules with sequential-sampling decision models (SSMs). + +Public API (import from ``hssm.rl``): + +- ``RLSSM``: the RL + SSM model class implemented in :mod:`hssm.rl.rlssm`. +- ``RLSSMConfig``: the config class for RL + SSM models, implemented in + :mod:`hssm.rl.config`. +- ``validate_balanced_panel``: panel-balance utility in :mod:`hssm.rl.utils`. + +RL likelihood builders live in :mod:`hssm.rl.likelihoods.builder` and include +helpers such as :func:`~hssm.rl.likelihoods.builder.make_rl_logp_func` and +:func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`. -- :class:`~hssm.rl.rlssm.RLSSM` — the RL + SSM model class. -- :func:`~hssm.rl.utils.validate_balanced_panel` — panel-balance utility. -- :mod:`hssm.rl.likelihoods` — log-likelihood builders - (:func:`~hssm.rl.likelihoods.builder.make_rl_logp_func`, - :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`). """ +from .config import RLSSMConfig from .rlssm import RLSSM from .utils import validate_balanced_panel __all__ = [ "RLSSM", + "RLSSMConfig", "validate_balanced_panel", ] diff --git a/src/hssm/rl/config.py b/src/hssm/rl/config.py new file mode 100644 index 00000000..dc8e1d2c --- /dev/null +++ b/src/hssm/rl/config.py @@ -0,0 +1,156 @@ +"""RL-specific configuration classes. + +This module houses `RLSSMConfig` which was previously defined in +`hssm.config`. It is intentionally lightweight and re-uses +`BaseModelConfig` from :mod:`hssm.config` to avoid duplicating core +behaviour. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .._types import LoglikKind, SupportedModels + from ..config import ModelConfig + +from ..config import BaseModelConfig + +_logger = logging.getLogger("hssm") + +# Local copy of required fields for RLSSM configs. Kept here so the class +# can be imported without importing the entirety of `hssm.config`'s runtime +# machinery earlier than necessary. +RLSSM_REQUIRED_FIELDS = ( + "model_name", + "description", + "list_params", + "bounds", + "params_default", + "choices", + "decision_process", + "learning_process", + "response", + "decision_process_loglik_kind", + "learning_process_loglik_kind", + "extra_fields", + "ssm_logp_func", +) + + +@dataclass +class RLSSMConfig(BaseModelConfig): + """Config for reinforcement learning + sequential sampling models. + + The ``ssm_logp_func`` field holds the fully annotated JAX SSM + log-likelihood function (an :class:`AnnotatedFunction`) that is passed + directly to ``make_rl_logp_op``. + """ + + decision_process_loglik_kind: str = field(kw_only=True) + learning_process_loglik_kind: str = field(kw_only=True) + params_default: list[float] = field(kw_only=True) + decision_process: str | "ModelConfig" = field(kw_only=True) + learning_process: dict[str, Any] = field(kw_only=True) + ssm_logp_func: Any = field(default=None, kw_only=True) + + def __post_init__(self): # noqa: D105 + if self.loglik_kind is None: + self.loglik_kind = "approx_differentiable" + _logger.debug( + "RLSSMConfig: loglik_kind not specified; " + "defaulting to 'approx_differentiable'." + ) + + @classmethod + def from_defaults( # noqa: D102 + cls, model_name: "SupportedModels" | str, loglik_kind: "LoglikKind" | None + ) -> "RLSSMConfig": + raise NotImplementedError( + "RLSSMConfig does not support from_defaults(). " + "Use RLSSMConfig.from_rlssm_dict() or the constructor directly." + ) + + @classmethod + def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": # noqa: D102 + for field_name in RLSSM_REQUIRED_FIELDS: + if field_name not in config_dict or config_dict[field_name] is None: + raise ValueError(f"{field_name} must be provided in config_dict") + + return cls( + model_name=config_dict["model_name"], + description=config_dict["description"], + list_params=config_dict["list_params"], + extra_fields=config_dict.get("extra_fields"), + params_default=config_dict["params_default"], + decision_process=config_dict["decision_process"], + learning_process=config_dict["learning_process"], + ssm_logp_func=config_dict["ssm_logp_func"], + bounds=config_dict.get("bounds", {}), + response=config_dict["response"], + choices=config_dict["choices"], + decision_process_loglik_kind=config_dict["decision_process_loglik_kind"], + learning_process_loglik_kind=config_dict["learning_process_loglik_kind"], + ) + + def validate(self) -> None: # noqa: D102 + if self.response is None: + raise ValueError("Please provide `response` columns in the configuration.") + if self.list_params is None: + raise ValueError("Please provide `list_params` in the configuration.") + if self.choices is None: + raise ValueError("Please provide `choices` in the configuration.") + if self.decision_process is None: + raise ValueError("Please specify a `decision_process`.") + + logpfunc = self.ssm_logp_func + if logpfunc is None: + raise ValueError( + "Please provide `ssm_logp_func`: the fully annotated JAX SSM " + "log-likelihood function required by `make_rl_logp_op`." + ) + if not callable(logpfunc): + raise ValueError( + f"`ssm_logp_func` must be a callable, but got {type(logpfunc)!r}." + ) + missing_attrs = [ + attr + for attr in ("inputs", "outputs", "computed") + if not hasattr(logpfunc, attr) + ] + if missing_attrs: + raise ValueError( + "`ssm_logp_func` must be decorated with `@annotate_function` " + "so that it carries the attributes required by `make_rl_logp_op`. " + f"Missing attribute(s): {missing_attrs}. " + ) + + if not isinstance(logpfunc.computed, dict) or not all( + callable(v) for v in logpfunc.computed.values() + ): + raise ValueError( + "`ssm_logp_func.computed` must be a dictionary with callable values." + ) + + if self.params_default and self.list_params: + if len(self.params_default) != len(self.list_params): + raise ValueError( + f"params_default length ({len(self.params_default)}) doesn't " + f"match list_params length ({len(self.list_params)})" + ) + + if self.list_params: + missing_bounds = [p for p in self.list_params if p not in self.bounds] + if missing_bounds: + raise ValueError( + f"Missing bounds for parameter(s): {missing_bounds}. " + "Every parameter in `list_params` must have a corresponding " + "entry in `bounds`." + ) + + def get_defaults( # noqa: D102 + self, param: str + ) -> tuple[float | None, tuple[float, float] | None]: + return None, self.bounds.get(param) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index c76a0a81..39fbfad2 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -16,6 +16,7 @@ standard ``loglik`` / ``loglik_kind`` wrapping pipeline. """ +from dataclasses import replace from typing import TYPE_CHECKING, Any, Callable, Literal, cast import bambi as bmb @@ -26,7 +27,6 @@ from pytensor.graph import Op -from hssm.config import Config, RLSSMConfig from hssm.defaults import ( INITVAL_JITTER_SETTINGS, ) @@ -35,6 +35,7 @@ from hssm.rl.utils import validate_balanced_panel from ..base import HSSMBase +from .config import RLSSMConfig class RLSSM(HSSMBase): @@ -97,8 +98,9 @@ class RLSSM(HSSMBase): Attributes ---------- - config : RLSSMConfig - The RLSSM configuration object. + model_config : RLSSMConfig + The RLSSM configuration object (stored as ``self.model_config`` on + :class:`~hssm.base.HSSMBase` with the built ``loglik`` Op injected). n_participants : int Number of participants inferred from *data*. n_trials : int @@ -161,9 +163,9 @@ def __init__( self.n_trials = n_trials # Build the differentiable pytensor Op from the annotated SSM function. - # This Op supersedes the loglik/loglik_kind workflow: it is passed as - # `loglik` to HSSMBase so Config.validate() is satisfied, and - # _make_model_distribution() uses it directly without any further wrapping. + # This Op supersedes the loglik/loglik_kind workflow: it is stored on + # rlssm_config.loglik so that HSSMBase can access it uniformly via + # self.model_config.loglik, without any Config conversion. # # Fresh list() copies are passed to make_rl_logp_op so the closure inside # captures its own isolated list objects. HSSMBase will later append @@ -178,14 +180,18 @@ def __init__( extra_fields=list(model_config.extra_fields or []), ) - # Build a typed Config instance via RLSSMConfig's own factory method. - # The differentiable Op is passed so Config.validate() is satisfied; - # loglik_kind="approx_differentiable" reflects that the Op has gradients. - config = model_config._build_model_config(loglik_op) + # Build a new RLSSMConfig with the Op and backend injected, leaving + # the caller's object unmodified (dataclasses.replace creates a shallow + # copy with only the specified fields overridden). + # + # backend is hardcoded to "jax" because the entire RLSSM likelihood + # stack is JAX-only. See ssm_logp_func, make_rl_logp_op, and + # _make_model_distribution for details. + model_config = replace(model_config, loglik=loglik_op, backend="jax") super().__init__( data=data, - model_config=config, + model_config=model_config, include=include, p_outlier=p_outlier, lapse=lapse, @@ -207,7 +213,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: through :func:`~hssm.distribution_utils.make_likelihood_callable`. Instead it uses ``self.loglik`` directly — the differentiable pytensor ``Op`` built in :meth:`__init__` from - ``self.config.ssm_logp_func``. + ``self.model_config.ssm_logp_func``. The Op already handles: - The RL learning rule (computing trial-wise intermediate parameters). @@ -218,10 +224,13 @@ def _make_model_distribution(self) -> type[pm.Distribution]: RLSSM and ``missing_data`` / ``deadline`` are rejected in ``__init__`` before this method is ever reached. """ - list_params = self.model_config.list_params - assert list_params is not None, "model_config.list_params must be set" + # Use self.list_params (managed by HSSMBase, includes p_outlier when + # has_lapse=True) rather than self.model_config.list_params (the original + # config list, never mutated by HSSMBase). + list_params = self.list_params + assert list_params is not None, "list_params must be set" assert isinstance(list_params, list), ( - "model_config.list_params must be a list" + "list_params must be a list" ) # for type checker # p_outlier is a scalar mixture weight (not trialwise); every other @@ -235,15 +244,13 @@ def _make_model_distribution(self) -> type[pm.Distribution]: else [self.data[field].to_numpy(copy=True) for field in extra_fields] ) - # The differentiable pytensor Op was stored on the validated model_config - # during __init__ as its `loglik`; ensure it's present and cast for typing. + # The differentiable pytensor Op was stored on model_config.loglik during + # __init__; ensure it's present and cast for typing. assert self.model_config.loglik is not None, "model_config.loglik must be set" loglik_op = cast("Callable[..., Any] | Op", self.model_config.loglik) - # `model_config` is typed as BaseModelConfig on the base class; cast - # to `Config` here so static checkers understand `rv` exists. - cfg = cast("Config", self.model_config) - rv_name = cfg.rv or cfg.model_name + # RLSSMConfig carries no `rv` field; use model_name as the rv identifier. + rv_name = self.model_config.model_name return make_distribution( rv=rv_name, diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index 98b34181..cf340480 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -15,7 +15,7 @@ import pytest import hssm -from hssm import RLSSM, RLSSMConfig +from hssm.rl import RLSSM, RLSSMConfig from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise from hssm.utils import annotate_function @@ -297,3 +297,30 @@ def test_rlssm_sample_smoke(rldm_data, rlssm_config) -> None: draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9 ) assert trace is not None + + +def test_rlssm_pickle_round_trip( + rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """cloudpickle round-trip must reconstruct an equivalent RLSSM. + + Verifies that __getstate__ / __setstate__ survive serialisation: + - The reconstructed object is a fresh RLSSM (not the same instance). + - n_participants and n_trials are preserved. + - list_params (including p_outlier) are preserved. + - model_config.model_name is preserved. + - model.model (bambi model) is rebuilt, confirming full re-initialisation. + """ + import cloudpickle + + model = RLSSM(data=rldm_data, model_config=rlssm_config) + blob = cloudpickle.dumps(model) + restored = cloudpickle.loads(blob) + + assert restored is not model + assert isinstance(restored, RLSSM) + assert restored.n_participants == model.n_participants + assert restored.n_trials == model.n_trials + assert restored.list_params == model.list_params + assert restored.model_config.model_name == model.model_config.model_name + assert restored.model is not None diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index 46859da4..513f4274 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -1,7 +1,8 @@ import pytest import hssm -from hssm.config import Config, ModelConfig, RLSSMConfig +from hssm.config import Config, ModelConfig +from hssm.rl import RLSSMConfig from hssm.utils import annotate_function # Define constants for repeated data structures @@ -68,6 +69,7 @@ def _dummy_ssm_logp_func(x): model_name="test_model", list_params=["alpha", "beta"], params_default=[0.5, 0.3], + bounds={"alpha": (0.0, 1.0), "beta": (0.0, 1.0)}, decision_process="ddm", response=["rt", "response"], choices=[0, 1], @@ -244,21 +246,69 @@ def test_validate_ssm_logp_func_missing_annotations(self, valid_rlssmconfig_kwar ): config.validate() + def test_validate_ssm_logp_func_computed_not_callable( + self, valid_rlssmconfig_kwargs + ): + """`computed` exists but contains non-callable values -> error.""" + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + # Inject a computed mapping with a non-callable value to trigger the + # specific validation branch. + config.ssm_logp_func.computed = {"x": "not_callable"} + with pytest.raises( + ValueError, + match=r"`ssm_logp_func.computed` must be a dictionary with callable values\.", + ): + config.validate() + + def test_validate_missing_bounds_for_param(self, valid_rlssmconfig_kwargs): + """validate() should raise early when a param has no bounds entry.""" + kwargs = {**valid_rlssmconfig_kwargs, "bounds": {}} # strip all bounds + config = RLSSMConfig(**kwargs) + with pytest.raises(ValueError, match="Missing bounds for parameter"): + config.validate() + + def test_from_defaults_raises(self): + """RLSSMConfig.from_defaults() must raise NotImplementedError.""" + with pytest.raises(NotImplementedError, match="from_defaults"): + RLSSMConfig.from_defaults("ddm", None) + class TestRLSSMConfigDefaults: @pytest.mark.parametrize( "list_params, params_default, bounds, param, expected_default, expected_bounds", [ + # params_default stores initialisation values, NOT priors. + # get_defaults always returns None for the prior so that + # prior_settings="safe" can assign priors from bounds. + # + # Case 1: queried param is present in bounds → bound returned. ( ["alpha", "beta", "gamma"], [0.5, 0.3, 0.2], - {"beta": (0.0, 1.0)}, + {"alpha": (0.0, 1.0), "beta": (0.0, 1.0), "gamma": (0.0, 1.0)}, "beta", - 0.3, + None, + (0.0, 1.0), + ), + # Case 2: queried param is NOT in list_params and NOT in bounds + # (e.g. a typo or an extra lookup) → both None. + ( + ["alpha", "beta"], + [0.5, 0.3], + {"alpha": (0.0, 1.0), "beta": (0.0, 1.0)}, + "gamma", + None, + None, + ), + # Case 3: params_default may be empty; param in bounds → bound returned. + ( + ["alpha", "beta"], + [], + {"alpha": (0.0, 1.0), "beta": (0.0, 1.0)}, + "alpha", + None, (0.0, 1.0), ), - (["alpha", "beta"], [0.5, 0.3], {"alpha": (0.0, 1.0)}, "gamma", None, None), - (["alpha", "beta"], [], {"alpha": (0.0, 1.0)}, "alpha", None, (0.0, 1.0)), ], ) def test_get_defaults_cases( @@ -287,145 +337,6 @@ def test_get_defaults_cases( assert bounds_val == expected_bounds -class TestRLSSMConfigConversion: - @pytest.mark.parametrize( - "list_params, params_default, backend, expected_backend, expected_default_priors, raises", - [ - ( - ["alpha", "beta", "v", "a"], - [0.5, 0.3, 1.0, 1.5], - "jax", - "jax", - {"alpha": 0.5, "beta": 0.3, "v": 1.0, "a": 1.5}, - None, - ), - (["alpha"], [0.5], None, "jax", {"alpha": 0.5}, None), - (["alpha", "beta"], [], None, "jax", {}, None), - (["alpha", "beta", "gamma"], [0.5, 0.3], None, None, None, ValueError), - ], - ) - def test_to_config_cases( - self, - list_params, - params_default, - backend, - expected_backend, - expected_default_priors, - raises, - ): - model_config = RLSSMConfig( - model_name="test_model", - list_params=list_params, - params_default=params_default, - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - backend=backend, - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - if raises: - with pytest.raises(raises): - model_config.to_config() - else: - config = model_config.to_config() - assert config.backend == expected_backend - assert config.default_priors == expected_default_priors - - def test_to_config(self): - model_config = RLSSMConfig( - model_name="rlwm", - description="RLWM model", - list_params=["alpha", "beta", "v", "a"], - params_default=[0.5, 0.3, 1.0, 1.5], - bounds={ - "alpha": (0.0, 1.0), - "beta": (0.0, 1.0), - "v": (-3.0, 3.0), - "a": (0.3, 2.5), - }, - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - extra_fields=["feedback"], - backend="jax", - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = model_config.to_config() - assert isinstance(config, Config) - assert config.model_name == "rlwm" - assert config.description == "RLWM model" - assert config.list_params == ["alpha", "beta", "v", "a"] - assert config.response == ["rt", "response"] - assert config.choices == [0, 1] - assert config.extra_fields == ["feedback"] - assert config.backend == "jax" - assert config.loglik_kind == "approx_differentiable" - assert config.bounds == { - "alpha": (0.0, 1.0), - "beta": (0.0, 1.0), - "v": (-3.0, 3.0), - "a": (0.3, 2.5), - } - assert config.default_priors == { - "alpha": 0.5, - "beta": 0.3, - "v": 1.0, - "a": 1.5, - } - - def test_to_config_defaults_backend(self): - model_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha"], - params_default=[0.5], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = model_config.to_config() - assert config.backend == "jax" - - def test_to_config_no_defaults(self): - model_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta"], - params_default=[], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = model_config.to_config() - assert config.default_priors == {} - - def test_to_config_mismatched_defaults_length(self): - model_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta", "gamma"], - params_default=[0.5, 0.3], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - with pytest.raises( - ValueError, - match=r"params_default length \(2\) doesn't match list_params length \(3\)", - ): - model_config.to_config() - - class TestRLSSMConfigLearningProcess: def test_learning_process(self): config = RLSSMConfig(