Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
aec1531
Refactor RLSSMConfig methods to simplify parameter handling and remov…
cpaniaguam Mar 18, 2026
a8cd51d
Fix handling of list_params in HSSMBase to ensure proper conversion f…
cpaniaguam Mar 18, 2026
9c22e26
Refactor RLSSM to inject model configuration directly, removing unnec…
cpaniaguam Mar 18, 2026
5658834
Update TestRLSSMConfigDefaults to reflect None for default parameters…
cpaniaguam Mar 18, 2026
7a294af
Refactor RLSSM to inject loglik and backend directly into a new RLSSM…
cpaniaguam Mar 18, 2026
fd99efb
Add validation for missing bounds in RLSSMConfig parameters
cpaniaguam Mar 18, 2026
bc0f7ca
Fix RLSSM to use model_config for ssm_logp_func and update test cases…
cpaniaguam Mar 18, 2026
b075e4f
Enhance RLSSM tests to align params_is_trialwise with list_params and…
cpaniaguam Mar 18, 2026
27d505e
Add test to ensure RLSSMConfig.from_defaults raises NotImplementedError
cpaniaguam Mar 18, 2026
ce8e187
Clarify RLSSMConfig.from_defaults behavior and raise NotImplementedEr…
cpaniaguam Mar 18, 2026
7c7fd32
Inject JAX backend into RLSSMConfig during initialization
cpaniaguam Mar 18, 2026
582a6fe
Merge branch '930-pass-configs-via-dependency-injection-into-model-cl…
cpaniaguam Mar 19, 2026
a3898d7
Fix merge conflicts with base branch
cpaniaguam Mar 19, 2026
4d99410
Remove commented out lines
cpaniaguam Mar 19, 2026
f04f47e
Remove RLSSMConfig import from __init__.py
cpaniaguam Mar 25, 2026
11115af
Reorganize import statements by moving RLSSMConfig import to the corr…
cpaniaguam Mar 25, 2026
6a9384f
Move RLSSMConfig import to the correct module in test files
cpaniaguam Mar 25, 2026
0285f04
Update docstring in __init__.py and exports
cpaniaguam Mar 25, 2026
5807a71
Remove RLSSMConfig class and its associated methods from config.py
cpaniaguam Mar 25, 2026
4bf67ea
Move RLSSMConfig class hssm.rl module
cpaniaguam Mar 25, 2026
5d74bfe
Refactor config.py to remove RLSSM-specific defaults and unify observ…
cpaniaguam Mar 27, 2026
91b1098
Enhance validation in RLSSMConfig for ssm_logp_func attributes
cpaniaguam Mar 27, 2026
c3a4f52
Add validation test for non-callable values in ssm_logp_func.computed
cpaniaguam Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/hssm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
import sys

from .config import ModelConfig, RLSSMConfig
from .config import ModelConfig
from .datasets import load_data
from .defaults import show_defaults
from .hssm import HSSM
Expand All @@ -33,7 +33,6 @@
__all__ = [
"HSSM",
"RLSSM",
"RLSSMConfig",
"Link",
"load_data",
"ModelConfig",
Expand Down
6 changes: 5 additions & 1 deletion src/hssm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,11 @@ def __init__(
if self.model_config.response is not None
else None
)
self.list_params = self.model_config.list_params
self.list_params = (
list(self.model_config.list_params)
if self.model_config.list_params is not None
else None
)
self.choices = self.model_config.choices # type: ignore[assignment]
self.model_name = self.model_config.model_name
self.loglik = self.model_config.loglik
Expand Down
289 changes: 2 additions & 287 deletions src/hssm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,10 @@
_logger = logging.getLogger("hssm")


# ====== Centralized RLSSM defaults =====
# ====== Centralized SSM defaults =====
DEFAULT_SSM_OBSERVED_DATA = ["rt", "response"]
DEFAULT_RLSSM_OBSERVED_DATA = ["rt", "response"]
DEFAULT_SSM_CHOICES = (0, 1)

RLSSM_REQUIRED_FIELDS = (
"model_name",
"description",
"list_params",
"bounds",
"params_default",
"choices",
"decision_process",
"learning_process",
"response",
"decision_process_loglik_kind",
"learning_process_loglik_kind",
"extra_fields",
"ssm_logp_func",
)

ParamSpec = Union[float, dict[str, Any], Prior, None]


Expand Down Expand Up @@ -196,7 +179,7 @@ def from_defaults(
return Config(
model_name=model_name,
loglik_kind=loglik_kind,
response=DEFAULT_RLSSM_OBSERVED_DATA,
response=DEFAULT_SSM_OBSERVED_DATA,
)

def update_loglik(self, loglik: Any | None) -> None:
Expand Down Expand Up @@ -324,274 +307,6 @@ def _build_model_config(
return config


@dataclass
class RLSSMConfig(BaseModelConfig):
"""Config for reinforcement learning + sequential sampling models.

This configuration class is designed for models that combine reinforcement
learning processes with sequential sampling decision models (RLSSM).

The ``ssm_logp_func`` field holds the fully annotated JAX SSM log-likelihood
function (an :class:`AnnotatedFunction`) that is passed directly to
``make_rl_logp_op``. It supersedes the ``loglik`` / ``loglik_kind`` workflow
used by :class:`HSSM`: the Op is built from ``ssm_logp_func`` and therefore
no ``loglik`` callable needs to be provided.
"""

decision_process_loglik_kind: str = field(kw_only=True)
learning_process_loglik_kind: str = field(kw_only=True)
params_default: list[float] = field(kw_only=True)
decision_process: str | ModelConfig = field(kw_only=True)
learning_process: dict[str, Any] = field(kw_only=True)
# The fully annotated SSM log-likelihood function used by make_rl_logp_op.
# Type is Any to avoid a hard dependency on the AnnotatedFunction Protocol at
# import time; validated at runtime in validate().
ssm_logp_func: Any = field(default=None, kw_only=True)

def __post_init__(self):
"""Set default loglik_kind for RLSSM models if not provided."""
if self.loglik_kind is None:
self.loglik_kind = "approx_differentiable"

@classmethod
def from_defaults(
cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None
) -> Config:
"""Return the shared Config defaults (delegated to :class:`Config`)."""
return Config.from_defaults(model_name, loglik_kind)

@classmethod
def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> RLSSMConfig:
"""
Create RLSSMConfig from a configuration dictionary.

Parameters
----------
config_dict : dict[str, Any]
Dictionary containing model configuration. Expected keys:
- model_name: Model identifier (required)
- description: Model description (required)
- list_params: List of parameter names (required)
- extra_fields: List of extra field names from data (required)
- params_default: Default parameter values (required)
- bounds: Parameter bounds (required)
- response: Response column names (required)
- choices: Valid choice values (required)
- decision_process: Decision process specification (required)
- learning_process: Learning process functions (required)
- decision_process_loglik_kind: Likelihood kind for decision process
(required)
- learning_process_loglik_kind: Likelihood kind for learning process
(required)

Returns
-------
RLSSMConfig
Configured RLSSM model configuration object.
"""
# Check for required fields and raise explicit errors if missing
for field_name in RLSSM_REQUIRED_FIELDS:
if field_name not in config_dict or config_dict[field_name] is None:
raise ValueError(f"{field_name} must be provided in config_dict")

return cls(
model_name=config_dict["model_name"],
description=config_dict["description"],
list_params=config_dict["list_params"],
extra_fields=config_dict.get("extra_fields"),
params_default=config_dict["params_default"],
decision_process=config_dict["decision_process"],
learning_process=config_dict["learning_process"],
ssm_logp_func=config_dict["ssm_logp_func"],
bounds=config_dict.get("bounds", {}),
response=config_dict["response"],
choices=config_dict["choices"],
decision_process_loglik_kind=config_dict["decision_process_loglik_kind"],
learning_process_loglik_kind=config_dict["learning_process_loglik_kind"],
)

def validate(self) -> None:
"""Validate RLSSM configuration.

Raises
------
ValueError
If required fields are missing or inconsistent.
"""
if self.response is None:
raise ValueError("Please provide `response` columns in the configuration.")
if self.list_params is None:
raise ValueError("Please provide `list_params` in the configuration.")
if self.choices is None:
raise ValueError("Please provide `choices` in the configuration.")
if self.decision_process is None:
raise ValueError("Please specify a `decision_process`.")
if self.ssm_logp_func is None:
raise ValueError(
"Please provide `ssm_logp_func`: the fully annotated JAX SSM "
"log-likelihood function required by `make_rl_logp_op`."
)
if not callable(self.ssm_logp_func):
raise ValueError(
"`ssm_logp_func` must be a callable, "
f"but got {type(self.ssm_logp_func)!r}."
)
missing_attrs = [
attr
for attr in ("inputs", "outputs", "computed")
if not hasattr(self.ssm_logp_func, attr)
]
if missing_attrs:
raise ValueError(
"`ssm_logp_func` must be decorated with `@annotate_function` "
"so that it carries the attributes required by `make_rl_logp_op`. "
f"Missing attribute(s): {missing_attrs}. "
"Decorate the function like:\n\n"
" @annotate_function(\n"
" inputs=[...], outputs=[...], computed={...}\n"
" )\n"
" def my_ssm_logp(lan_matrix): ..."
)

# Validate parameter defaults consistency
if self.params_default and self.list_params:
if len(self.params_default) != len(self.list_params):
raise ValueError(
f"params_default length ({len(self.params_default)}) doesn't "
f"match list_params length ({len(self.list_params)})"
)

def get_defaults(
self, param: str
) -> tuple[float | None, tuple[float, float] | None]:
"""Return default value and bounds for a parameter.

Parameters
----------
param
The name of the parameter.

Returns
-------
tuple
A tuple of (default_value, bounds) where:
- default_value is a float or None if not found
- bounds is a tuple (lower, upper) or None if not found
"""
# Try to find the parameter in list_params and get its default value
default_val = None
if self.list_params is not None:
try:
param_idx = self.list_params.index(param)
if self.params_default and param_idx < len(self.params_default):
default_val = self.params_default[param_idx]
except ValueError:
# Parameter not in list_params
pass

return default_val, self.bounds.get(param)

def to_config(self) -> Config:
"""Convert to standard Config for compatibility with HSSM.

This method transforms the RLSSM configuration into a standard Config
object that can be used with the existing HSSM infrastructure.

Returns
-------
Config
A Config object with RLSSM parameters mapped to standard format.

Notes
-----
The transformation converts params_default list to default_priors dict,
mapping parameter names to their default values.
"""
# Validate parameter defaults consistency before conversion
if self.params_default and self.list_params:
if len(self.params_default) != len(self.list_params):
raise ValueError(
f"params_default length ({len(self.params_default)}) doesn't "
f"match list_params length ({len(self.list_params)}). "
"This would result in silent data loss during conversion."
)

# Transform params_default list to default_priors dict
default_priors = (
{
param: default
for param, default in zip(self.list_params, self.params_default)
}
if self.list_params and self.params_default
else {}
)

return Config(
model_name=self.model_name,
loglik_kind=self.loglik_kind,
response=self.response,
choices=self.choices,
list_params=self.list_params,
description=self.description,
bounds=self.bounds,
default_priors=cast(
"dict[str, float | dict[str, Any] | Any | None]", default_priors
),
extra_fields=self.extra_fields,
backend=self.backend or "jax", # RLSSM typically uses JAX
loglik=self.loglik,
)

def to_model_config(self) -> ModelConfig:
"""Build a :class:`ModelConfig` from this :class:`RLSSMConfig`.

All fields are sourced from ``self``; the backend is fixed to ``"jax"``
because RLSSM exclusively uses the JAX backend.

``default_priors`` is intentionally left empty so the
``prior_settings="safe"`` mechanism in :class:`~hssm.base.HSSMBase`
assigns sensible priors from bounds rather than fixing every parameter
to a constant scalar.
"""
return ModelConfig(
response=tuple(self.response), # type: ignore[arg-type]
list_params=list(self.list_params), # type: ignore[arg-type]
choices=tuple(self.choices), # type: ignore[arg-type]
default_priors={},
bounds=self.bounds,
extra_fields=self.extra_fields,
backend="jax",
)

def _build_model_config(self, loglik_op: Any) -> Config:
"""Build a validated :class:`Config` for use by :class:`~hssm.rl.rlssm.RLSSM`.

Converts this :class:`RLSSMConfig` to a :class:`ModelConfig`, then
delegates to :meth:`Config._build_model_config` using the pre-built
differentiable Op as ``loglik``.

Parameters
----------
loglik_op
The differentiable pytensor Op produced by
:func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`.

Returns
-------
Config
A fully validated :class:`Config` ready to pass to
:meth:`~hssm.base.HSSMBase.__init__`.
"""
mc = self.to_model_config()
return Config._build_model_config(
self.model_name,
"approx_differentiable",
mc,
None,
loglik_op,
)


@dataclass
class ModelConfig:
"""Representation for model_config provided by the user."""
Expand Down
23 changes: 16 additions & 7 deletions src/hssm/rl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
"""Reinforcement learning extensions for HSSM.
"""Reinforcement-learning extensions for HSSM.

This sub-package provides:
This subpackage groups components that integrate reinforcement-learning
learning rules with sequential-sampling decision models (SSMs).

Public API (import from ``hssm.rl``):

- ``RLSSM``: the RL + SSM model class implemented in :mod:`hssm.rl.rlssm`.
- ``RLSSMConfig``: the config class for RL + SSM models, implemented in
:mod:`hssm.rl.config`.
- ``validate_balanced_panel``: panel-balance utility in :mod:`hssm.rl.utils`.

RL likelihood builders live in :mod:`hssm.rl.likelihoods.builder` and include
helpers such as :func:`~hssm.rl.likelihoods.builder.make_rl_logp_func` and
:func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`.

- :class:`~hssm.rl.rlssm.RLSSM` — the RL + SSM model class.
- :func:`~hssm.rl.utils.validate_balanced_panel` — panel-balance utility.
- :mod:`hssm.rl.likelihoods` — log-likelihood builders
(:func:`~hssm.rl.likelihoods.builder.make_rl_logp_func`,
:func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`).
"""

from .config import RLSSMConfig
from .rlssm import RLSSM
from .utils import validate_balanced_panel

__all__ = [
"RLSSM",
"RLSSMConfig",
"validate_balanced_panel",
]
Loading