Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
20ddc4c
Add ssm_logp_func to RLSSMConfig and update validation tests
cpaniaguam Mar 2, 2026
d97dcee
Add RLSSM model and utilities for reinforcement learning integration
cpaniaguam Mar 2, 2026
a6a0238
Refactor RLSSM parameter handling and add custom prefix resolution fo…
cpaniaguam Mar 2, 2026
d880977
Add tests for RLSSM class covering initialization, validation, and mo…
cpaniaguam Mar 2, 2026
bef8d6c
Refactor loglik handling in RLSSM to improve type safety with casting
cpaniaguam Mar 2, 2026
3981ef6
Add NaN value check for participant column in validate_balanced_panel…
cpaniaguam Mar 2, 2026
d84a800
Add validation for ssm_logp_func in RLSSMConfig to ensure it is calla…
cpaniaguam Mar 2, 2026
15ad6e2
Add exclude rules for ruff and mypy hooks to skip tests directory
cpaniaguam Mar 2, 2026
262ec07
Add validation tests for ssm_logp_func in RLSSMConfig to ensure it is…
cpaniaguam Mar 2, 2026
381275a
Add tests for NaN participant_id and unannotated ssm_logp_func in RLSSM
cpaniaguam Mar 2, 2026
0e9ba42
Reject missing data and deadline handling in RLSSM initialization to …
cpaniaguam Mar 2, 2026
4f28c68
Add tests to validate error handling for missing data and deadline in…
cpaniaguam Mar 2, 2026
5e9f566
Refactor path handling for loading RLDM fixture dataset in tests
cpaniaguam Mar 2, 2026
67ac2ce
Add fixture to set floatX to float32 for module tests
cpaniaguam Mar 2, 2026
e1c05df
Ensure params_is_trialwise aligns with list_params in RLSSM initializ…
cpaniaguam Mar 2, 2026
564232b
Clarify comments on default_priors in ModelConfig and remove unnecess…
cpaniaguam Mar 2, 2026
bafc037
Update RLSSM to use to_numpy(copy=True) for extra_fields and add test…
cpaniaguam Mar 2, 2026
ba358a4
Refactor parameter name resolution in RLSSM to handle underscores cor…
cpaniaguam Mar 2, 2026
0bfa755
Add test for _get_prefix method in RLSSM to ensure token-based matching
cpaniaguam Mar 2, 2026
5b8a16a
Refactor RLSSMConfig.from_rlssm_dict to remove model_name parameter a…
cpaniaguam Mar 2, 2026
f69f2b6
Fix comment in test_rlssm.py to clarify output shape of log-likelihoo…
cpaniaguam Mar 3, 2026
bad943d
Update RLSSMConfig documentation to mark description as required
cpaniaguam Mar 3, 2026
241aad2
Add ssm_logp_func to RLSSM_REQUIRED_FIELDS and update RLSSMConfig ini…
cpaniaguam Mar 3, 2026
ca3816d
Add dummy ssm_logp_func to tests and validate its presence in RLSSMCo…
cpaniaguam Mar 3, 2026
827025c
Remove unused logging import from rlssm.py
cpaniaguam Mar 3, 2026
292d6f0
Remove redundant exclude rule for ruff-format in pre-commit configura…
cpaniaguam Mar 3, 2026
3b1aaf4
Add to_model_config method to RLSSMConfig for ModelConfig conversion
cpaniaguam Mar 4, 2026
0678c45
Refactor RLSSM to delegate ModelConfig construction to RLSSMConfig an…
cpaniaguam Mar 4, 2026
26a9336
Integrate Config and RLSSMConfig into HSSM and RLSSM classes for impr…
cpaniaguam Mar 9, 2026
cd660da
Update choices type from list to tuple for consistency in BaseModelCo…
cpaniaguam Mar 9, 2026
db308c6
Update choices type from list to tuple in test_constructor for consis…
cpaniaguam Mar 9, 2026
ef000cd
Fix formatting of error messages in TestRLSSMConfigValidation for con…
cpaniaguam Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ repos:
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
exclude: ^tests/
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.19.1 # Use the sha / tag you want to point at
hooks:
- id: mypy
args: [--no-strict-optional, --ignore-missing-imports]
exclude: ^tests/
5 changes: 4 additions & 1 deletion src/hssm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
import logging
import sys

from .config import ModelConfig
from .config import ModelConfig, RLSSMConfig
from .datasets import load_data
from .defaults import show_defaults
from .hssm import HSSM
from .link import Link
from .param import UserParam as Param
from .prior import Prior
from .register import register_model
from .rl import RLSSM
from .simulator import simulate_data
from .utils import check_data_for_rl, set_floatX

Expand All @@ -31,6 +32,8 @@

__all__ = [
"HSSM",
"RLSSM",
"RLSSMConfig",
"Link",
"load_data",
"ModelConfig",
Expand Down
5 changes: 2 additions & 3 deletions src/hssm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,6 @@ class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin):
The jitter value for the initial values.
"""

config_class = Config

def __init__(
self,
data: pd.DataFrame,
Expand Down Expand Up @@ -529,7 +527,8 @@ def _build_model_config(
A complete Config object with choices and other settings applied.
"""
# Start with defaults
config = cls.config_class.from_defaults(model, loglik_kind)
# get_config_class is provided by Config/RLSSMConfig mixin through MRO
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does RLSSMConfig show up here in this file?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All this will be cleaned up after #936 and #931 get merged into their respective base branches.

config = cls.get_config_class().from_defaults(model, loglik_kind) # type: ignore[attr-defined]

# Handle user-provided model_config
if model_config is not None:
Expand Down
93 changes: 89 additions & 4 deletions src/hssm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"decision_process_loglik_kind",
"learning_process_loglik_kind",
"extra_fields",
"ssm_logp_func",
)

ParamSpec = Union[float, dict[str, Any], Prior, None]
Expand Down Expand Up @@ -67,6 +68,16 @@ class BaseModelConfig(ABC):
# Additional data requirements
extra_fields: list[str] | None = None

@classmethod
@abstractmethod
def get_config_class(cls) -> type["BaseModelConfig"]:
    """Return the config class for this model type.

    This enables polymorphic config resolution without circular imports.
    Each subclass returns itself as the config class.

    Returns
    -------
    type[BaseModelConfig]
        The concrete configuration class to instantiate for this model
        family (e.g. ``Config`` for HSSM, ``RLSSMConfig`` for RLSSM).
    """
    ...

@abstractmethod
def validate(self) -> None:
"""Validate configuration. Must be implemented by subclasses."""
Expand All @@ -91,6 +102,11 @@ def __post_init__(self):
if self.loglik_kind is None:
raise ValueError("loglik_kind is required for Config")

@classmethod
def get_config_class(cls) -> type["Config"]:
    """Return Config as the config class for HSSM models.

    Note: returns ``Config`` itself (not ``cls``) so subclasses of
    ``Config`` still resolve to the shared HSSM configuration class.
    """
    return Config

@classmethod
def from_defaults(
cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None
Expand Down Expand Up @@ -266,19 +282,41 @@ class RLSSMConfig(BaseModelConfig):

This configuration class is designed for models that combine reinforcement
learning processes with sequential sampling decision models (RLSSM).

The ``ssm_logp_func`` field holds the fully annotated JAX SSM log-likelihood
function (an :class:`AnnotatedFunction`) that is passed directly to
``make_rl_logp_op``. It supersedes the ``loglik`` / ``loglik_kind`` workflow
used by :class:`HSSM`: the Op is built from ``ssm_logp_func`` and therefore
no ``loglik`` callable needs to be provided.
"""

decision_process_loglik_kind: str = field(kw_only=True)
learning_process_loglik_kind: str = field(kw_only=True)
params_default: list[float] = field(kw_only=True)
decision_process: str | ModelConfig = field(kw_only=True)
learning_process: dict[str, Any] = field(kw_only=True)
# The fully annotated SSM log-likelihood function used by make_rl_logp_op.
# Type is Any to avoid a hard dependency on the AnnotatedFunction Protocol at
# import time; validated at runtime in validate().
ssm_logp_func: Any = field(default=None, kw_only=True)

def __post_init__(self):
    """Fill in the RLSSM default ``loglik_kind`` when none was supplied."""
    # Only an explicit None triggers the fallback; any user-provided
    # value (including an empty string) is left untouched.
    if self.loglik_kind is not None:
        return
    self.loglik_kind = "approx_differentiable"

@classmethod
def get_config_class(cls) -> type["RLSSMConfig"]:
    """Return RLSSMConfig as the config class for RLSSM models.

    Counterpart of ``Config.get_config_class``; used for polymorphic
    config resolution without circular imports.
    """
    return RLSSMConfig

@classmethod
def from_defaults(
    cls, model_name: SupportedModels | str, loglik_kind: LoglikKind | None
) -> Config:
    """Return the shared Config defaults (delegated to :class:`Config`).

    NOTE(review): despite living on ``RLSSMConfig``, this returns a plain
    :class:`Config` (per the annotation), not an ``RLSSMConfig`` — confirm
    that callers resolving defaults through this class expect that.
    """
    return Config.from_defaults(model_name, loglik_kind)

@property
def n_params(self) -> int | None:
"""Return the number of parameters."""
Expand All @@ -290,17 +328,16 @@ def n_extra_fields(self) -> int | None:
return len(self.extra_fields) if self.extra_fields else None

@classmethod
def from_rlssm_dict(cls, model_name: str, config_dict: dict[str, Any]):
def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig":
"""
Create RLSSMConfig from a configuration dictionary.

Parameters
----------
model_name : str
The name of the RLSSM model.
config_dict : dict[str, Any]
Dictionary containing model configuration. Expected keys:
- description: Model description (optional)
- model_name: Model identifier (required)
- description: Model description (required)
- list_params: List of parameter names (required)
- extra_fields: List of extra field names from data (required)
- params_default: Default parameter values (required)
Expand Down Expand Up @@ -332,6 +369,7 @@ def from_rlssm_dict(cls, model_name: str, config_dict: dict[str, Any]):
params_default=config_dict["params_default"],
decision_process=config_dict["decision_process"],
learning_process=config_dict["learning_process"],
ssm_logp_func=config_dict["ssm_logp_func"],
bounds=config_dict.get("bounds", {}),
response=config_dict["response"],
choices=config_dict["choices"],
Expand All @@ -355,6 +393,32 @@ def validate(self) -> None:
raise ValueError("Please provide `choices` in the configuration.")
if self.decision_process is None:
raise ValueError("Please specify a `decision_process`.")
if self.ssm_logp_func is None:
raise ValueError(
"Please provide `ssm_logp_func`: the fully annotated JAX SSM "
"log-likelihood function required by `make_rl_logp_op`."
)
if not callable(self.ssm_logp_func):
raise ValueError(
"`ssm_logp_func` must be a callable, "
f"but got {type(self.ssm_logp_func)!r}."
)
missing_attrs = [
attr
for attr in ("inputs", "outputs", "computed")
if not hasattr(self.ssm_logp_func, attr)
]
if missing_attrs:
raise ValueError(
"`ssm_logp_func` must be decorated with `@annotate_function` "
"so that it carries the attributes required by `make_rl_logp_op`. "
f"Missing attribute(s): {missing_attrs}. "
"Decorate the function like:\n\n"
" @annotate_function(\n"
" inputs=[...], outputs=[...], computed={...}\n"
" )\n"
" def my_ssm_logp(lan_matrix): ..."
)

# Validate parameter defaults consistency
if self.params_default and self.list_params:
Expand Down Expand Up @@ -445,6 +509,27 @@ def to_config(self) -> Config:
loglik=self.loglik,
)

def to_model_config(self) -> "ModelConfig":
    """Convert this :class:`RLSSMConfig` into a :class:`ModelConfig`.

    Every field is drawn from ``self``; ``backend`` is pinned to ``"jax"``
    because RLSSM exclusively uses the JAX backend.

    ``default_priors`` is deliberately left empty so that the
    ``prior_settings="safe"`` mechanism in :class:`~hssm.base.HSSMBase`
    derives sensible priors from the bounds instead of fixing every
    parameter to a constant scalar.
    """
    model_config_kwargs: dict[str, Any] = {
        "response": tuple(self.response),
        "list_params": list(self.list_params),
        "choices": tuple(self.choices),
        "default_priors": {},
        "bounds": self.bounds,
        "extra_fields": self.extra_fields,
        "backend": "jax",
    }
    return ModelConfig(**model_config_kwargs)  # type: ignore[arg-type]


@dataclass
class ModelConfig:
Expand Down
4 changes: 2 additions & 2 deletions src/hssm/data_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class DataValidatorMixin:
This class expects subclasses to define the following attributes:
- data: pd.DataFrame
- response: list[str]
- choices: list[int]
- choices: list[int] | tuple[int, ...]
- n_choices: int
- extra_fields: list[str] | None
- deadline: bool
Expand All @@ -30,7 +30,7 @@ def __init__(
self,
data: pd.DataFrame,
response: list[str] | None = ["rt", "response"],
choices: list[int] | None = [0, 1],
choices: tuple[int, ...] | None = (0, 1),
n_choices: int = 2,
extra_fields: list[str] | None = None,
deadline: bool = False,
Expand Down
3 changes: 2 additions & 1 deletion src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)

from .base import HSSMBase
from .config import Config

if TYPE_CHECKING:
from os import PathLike
Expand Down Expand Up @@ -74,7 +75,7 @@ def __get__(self, instance, owner): # noqa: D105
return self.fget(owner)


class HSSM(HSSMBase):
class HSSM(HSSMBase, Config):
"""The basic Hierarchical Sequential Sampling Model (HSSM) class.

Parameters
Expand Down
18 changes: 18 additions & 0 deletions src/hssm/rl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Reinforcement learning extensions for HSSM.

This sub-package provides:

- :class:`~hssm.rl.rlssm.RLSSM` — the RL + SSM model class.
- :func:`~hssm.rl.utils.validate_balanced_panel` — panel-balance utility.
- :mod:`hssm.rl.likelihoods` — log-likelihood builders
(:func:`~hssm.rl.likelihoods.builder.make_rl_logp_func`,
:func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`).
"""

from .rlssm import RLSSM
from .utils import validate_balanced_panel

__all__ = [
"RLSSM",
"validate_balanced_panel",
]
Loading