
Rlssm class make model dist #915

Draft
cpaniaguam wants to merge 32 commits into cp-main-sb from rlssm-class-make-model-dist

Conversation

@cpaniaguam
Collaborator

@cpaniaguam cpaniaguam commented Mar 2, 2026

This pull request introduces reinforcement learning sequential sampling model (RLSSM) support to the HSSM package. It adds a new RLSSM class, supporting configuration, likelihood construction, and data validation for RL+SSM models, and refines the configuration workflow to require a fully annotated log-likelihood function. The changes also improve pre-commit configuration and update the package's public API.

Major features and changes:

1. RLSSM Model Integration

  • Added a new RLSSM class in src/hssm/rl/rlssm.py to support models that combine reinforcement learning processes with sequential sampling models. This class builds a differentiable PyTensor Op from an annotated JAX log-likelihood function and enforces strict data requirements for balanced panels.
  • Introduced a utility function validate_balanced_panel in src/hssm/rl/utils.py to ensure input data forms a balanced panel, which is required for RLSSM models.
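The balanced-panel contract can be sketched as a minimal standalone function, mirroring the snippet quoted later in the review comments (the real helper lives in src/hssm/rl/utils.py):

```python
import pandas as pd


def validate_balanced_panel(
    data: pd.DataFrame, participant_col: str = "participant_id"
) -> tuple[int, int]:
    """Return (n_participants, n_trials) for a balanced panel, else raise."""
    # Count trials per participant; a balanced panel requires equal counts.
    counts = data.groupby(participant_col).size()
    if counts.nunique() != 1:
        raise ValueError(
            "Data must form balanced panels: all participants must have the "
            f"same number of trials. Observed trial counts: {dict(counts)}"
        )
    return int(len(counts)), int(counts.iloc[0])


# Two participants with two trials each -> (2, 2)
df = pd.DataFrame({"participant_id": [0, 0, 1, 1], "rt": [0.4, 0.6, 0.5, 0.7]})
print(validate_balanced_panel(df))  # (2, 2)
```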

2. Configuration Enhancements

  • Extended RLSSMConfig in src/hssm/config.py to require an ssm_logp_func (an annotated JAX SSM log-likelihood function), replacing the previous loglik/loglik_kind workflow. Added runtime validation to ensure this function is callable and properly annotated. [1] [2] [3]
  • Updated from_rlssm_dict to accept a config dictionary and extract ssm_logp_func and model_name directly from it, simplifying model instantiation.

3. Public API and Package Structure

  • Registered RLSSM and RLSSMConfig in the package's public API via src/hssm/__init__.py and created a new src/hssm/rl/__init__.py for RL-related exports. [1] [2] [3]

4. Developer Experience

  • Updated .pre-commit-config.yaml to exclude the tests/ directory from ruff and mypy checks, streamlining development workflows.

@cpaniaguam cpaniaguam changed the base branch from main to cp-main-sb March 2, 2026 18:41
Contributor

Copilot AI left a comment


Pull request overview

Adds first-class RL + SSM (RLSSM) support to HSSM by introducing a new RLSSM model that builds a differentiable PyTensor Op from an annotated JAX SSM log-likelihood and plugs it into the existing distribution-building pipeline.

Changes:

  • Introduces RLSSM model class plus RL utility validate_balanced_panel.
  • Extends configuration via RLSSMConfig.ssm_logp_func and exposes RLSSM in the public API.
  • Adds test coverage for RLSSM initialization/model build and updates RLSSMConfig validation tests.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Show a summary per file

  • src/hssm/rl/rlssm.py: New RLSSM model implementation integrating the RL likelihood Op into HSSMBase.
  • src/hssm/rl/utils.py: Adds a balanced-panel validation helper for RLSSM datasets.
  • src/hssm/rl/__init__.py: RL subpackage exports for RLSSM and utilities.
  • src/hssm/config.py: Adds ssm_logp_func to RLSSMConfig and validates its presence.
  • src/hssm/__init__.py: Exposes RLSSM / RLSSMConfig at top level.
  • tests/test_rlssm.py: New end-to-end-ish RLSSM tests (init, model build, balanced panel, smoke sampling).
  • tests/test_rlssm_config.py: Updates RLSSMConfig tests to include the new required field.


            f"same number of trials. Observed trial counts: {dict(counts)}"
        )

    return int(len(counts)), int(counts.iloc[0])

Copilot AI Mar 2, 2026


validate_balanced_panel only checks equal trial counts, but the RL likelihood builder reshapes the row order into (n_participants, n_trials, ...) (see make_rl_logp_func), which assumes each participant’s trials are in one contiguous block (and usually in-trial order). With interleaved participants, the panel can be “balanced” yet produce a silently incorrect likelihood. Consider validating contiguity (each participant appears in exactly one run of length n_trials) and/or sorting by participant_col (+ an optional trial_col if present) before returning (n_participants, n_trials).

Suggested change

-    return int(len(counts)), int(counts.iloc[0])
+    # Ensure that each participant's trials form a single contiguous block
+    # of rows of length n_trials. This is required because downstream code
+    # reshapes the data into (n_participants, n_trials, ...) based on row
+    # order, assuming no interleaving across participants.
+    n_trials = int(counts.iloc[0])
+    # Identify contiguous "blocks" of identical participant IDs.
+    blocks = data[participant_col].ne(data[participant_col].shift()).cumsum()
+    block_counts = data.groupby([participant_col, blocks]).size()
+    # Each participant must appear in exactly one block, and that block
+    # must have length n_trials.
+    blocks_per_participant = block_counts.groupby(level=0).size()
+    invalid_multi_blocks = blocks_per_participant[blocks_per_participant != 1]
+    invalid_block_sizes = block_counts[block_counts != n_trials]
+    if not invalid_multi_blocks.empty or not invalid_block_sizes.empty:
+        raise ValueError(
+            "Data must be ordered so that each participant's trials appear in "
+            "a single contiguous block of rows of length n_trials. "
+            "Participants with non-contiguous or incorrectly sized blocks "
+            f"were found. Consider sorting your data by '{participant_col}' "
+            "and, if available, by a trial index column before building the "
+            "RL likelihood."
+        )
+    return int(len(counts)), n_trials

            "Please provide the correct participant column name via "
            "`participant_col`."
        )


Copilot AI Mar 2, 2026


groupby(participant_col) drops NaN participant IDs by default, which can make n_participants/n_trials incorrect without an explicit error. Consider adding a check like data[participant_col].isna().any() and raising a clear ValueError if participant IDs are missing.

Suggested change

+    # Ensure there are no missing participant IDs, since groupby will drop NaNs
+    # silently, which would make n_participants / n_trials incorrect.
+    if data[participant_col].isna().any():
+        raise ValueError(
+            f"Column '{participant_col}' contains missing values. "
+            "Please fill or remove rows with missing participant IDs before "
+            "calling validate_balanced_panel."
+        )

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.



Comment on lines +242 to +246

        )

        # Rearrange data so missing rows come first (no-op when missing_data=False).
        self.data = _rearrange_data(self.data)


Copilot AI Mar 2, 2026


_rearrange_data(self.data) changes row order, but the RL logp Op reshapes trials purely by row order into (n_participants, n_trials, ...). If any rows are moved (e.g., when missing_data=True and rt == -999), this will break per-participant trial sequences and invalidate the RL learning dynamics. Since missing-data networks are not supported for RLSSM, consider raising an explicit error when missing_data/deadline handling is requested (or implement a participant-wise rearrangement that preserves within-subject order).
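One possible shape for the explicit guard this comment asks for — a sketch with a hypothetical helper name, not the PR's actual code:

```python
def reject_unsupported_options(missing_data: bool, deadline: bool) -> None:
    # The RL logp Op reshapes rows purely by order into
    # (n_participants, n_trials, ...), so any row rearrangement done for
    # missing-data or deadline handling would scramble per-participant trial
    # sequences. Fail fast instead of silently corrupting RL dynamics.
    if missing_data is not False:
        raise NotImplementedError(
            "RLSSM does not support `missing_data` handling yet."
        )
    if deadline is not False:
        raise NotImplementedError(
            "RLSSM does not support `deadline` handling yet."
        )


reject_unsupported_options(False, False)  # no-op when both features are off
```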

Comment on lines +49 to +56

    counts = data.groupby(participant_col).size()
    if counts.nunique() != 1:
        raise ValueError(
            "Data must form balanced panels: all participants must have the "
            f"same number of trials. Observed trial counts: {dict(counts)}"
        )

    return int(len(counts)), int(counts.iloc[0])

Copilot AI Mar 2, 2026


validate_balanced_panel() only checks equal trial counts via groupby().size(), but it does not validate that rows are ordered/grouped by participant. The RL likelihood builder (make_rl_logp_func) reshapes arrays with .reshape(n_participants, n_trials, -1) based purely on row order, so interleaved participant rows will silently mix subjects/trials and produce an incorrect likelihood. Consider either (a) enforcing contiguous blocks per participant (and optionally stable-sorting by participant_col + a trial index column if available) or (b) returning a sorted copy of the data and using that downstream.

@cpaniaguam cpaniaguam requested a review from Copilot March 2, 2026 20:17
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.



Comment on lines +249 to +275

        # All RLSSM parameters are treated as trialwise: the Op expects arrays of
        # length n_total_trials for every parameter, and make_distribution.logp
        # broadcasts scalar / (1,)-shaped tensors up to (n_obs,) accordingly.
        params_is_trialwise = [
            True for param_name in self.params if param_name != "p_outlier"
        ]

        extra_fields_data = (
            None
            if not self.extra_fields
            else [deepcopy(self.data[field].values) for field in self.extra_fields]
        )

        assert self.list_params is not None, "list_params should be set"
        # self.loglik was set to the pytensor Op built in __init__; cast to
        # narrow the inherited union type so make_distribution's type-checker
        # accepts it without a runtime penalty.
        loglik_op = cast("Callable[..., Any] | Op", self.loglik)
        return make_distribution(
            rv=self.model_name,
            loglik=loglik_op,
            list_params=self.list_params,
            bounds=self.bounds,
            lapse=self.lapse,
            extra_fields=extra_fields_data,
            params_is_trialwise=params_is_trialwise,
        )

Copilot AI Mar 2, 2026


params_is_trialwise is derived from self.params (excluding p_outlier), but it is passed alongside list_params=self.list_params. If self.list_params includes p_outlier (common in HSSMBase), this makes params_is_trialwise shorter and potentially misaligned with list_params, which can cause incorrect broadcasting or length-check failures in make_distribution. Build params_is_trialwise from self.list_params in the same order, marking p_outlier as non-trialwise.
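The alignment fix the comment suggests could be sketched as follows (hypothetical free function for illustration; the PR builds the list inline):

```python
def build_params_is_trialwise(list_params: list[str]) -> list[bool]:
    # Derive the flags from list_params itself so both sequences share the
    # same length and order; only p_outlier is treated as non-trialwise.
    return [name != "p_outlier" for name in list_params]


print(build_params_is_trialwise(["v", "a", "z", "t", "p_outlier"]))
# [True, True, True, True, False]
```

Passing a flag list of the same length and order as `list_params` keeps `make_distribution` from misaligning broadcasting decisions when `p_outlier` is appended.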

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.



Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.



Comment on lines +168 to +185
# Build the differentiable pytensor Op from the annotated SSM function.
# This Op supersedes the loglik/loglik_kind workflow: it is passed as
# `loglik` to HSSMBase so Config.validate() is satisfied, and
# _make_model_distribution() uses it directly without any further wrapping.
#
# Pass copies of list_params / extra_fields so the closure inside
# make_rl_logp_func captures its own isolated list objects. HSSMBase will
# later append "p_outlier" to self.list_params (which is the SAME list
# object as `list_params` above), and that mutation must NOT be visible to
# the Op's _validate_args_length check at sampling time.
loglik_op = make_rl_logp_op(
ssm_logp_func=rlssm_config.ssm_logp_func,
n_participants=n_participants,
n_trials=n_trials,
data_cols=list(data_cols),
list_params=list(list_params),
extra_fields=list(extra_fields),
)

Copilot AI Mar 3, 2026


RLSSM builds the Op exclusively from rlssm_config.ssm_logp_func (and its .computed metadata) but never uses rlssm_config.learning_process. Since learning_process is still a required RLSSMConfig field, this creates a confusing/fragile API where users can supply learning functions that are silently ignored (or diverge from ssm_logp_func.computed). Consider making learning_process optional/removing it from RLSSMConfig, or validating it matches (or populates) ssm_logp_func.computed so there is a single source of truth.
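One way to keep a single source of truth, sketched as a hypothetical validator (the `.computed` attribute is the metadata the comment mentions; `SimpleNamespace` stands in for the real config and function objects):

```python
from types import SimpleNamespace


def check_learning_process(config) -> None:
    # Hypothetical consistency check: if both learning_process and the
    # ssm_logp_func's `.computed` metadata are supplied, require agreement so
    # users cannot pass a learning function that is silently ignored.
    computed = getattr(config.ssm_logp_func, "computed", None)
    if config.learning_process and computed is not None:
        if config.learning_process != computed:
            raise ValueError(
                "RLSSMConfig.learning_process disagrees with "
                "ssm_logp_func.computed; supply only one source of truth."
            )


# Toy example with matching metadata: passes silently.
fn = SimpleNamespace(computed={"q": "rescorla_wagner"})
cfg = SimpleNamespace(ssm_logp_func=fn, learning_process={"q": "rescorla_wagner"})
check_learning_process(cfg)
```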

@cpaniaguam
Collaborator Author

@krishnbera @AlexanderFengler @digicosmos86 Here is a first draft for the RLSSM class.

    def __init__(
        self,
        data: pd.DataFrame,
        rlssm_config: RLSSMConfig,
Collaborator


I think we should keep the two classes as similar to each other as possible, so I'd prefer model_config here

        data: pd.DataFrame,
        rlssm_config: RLSSMConfig,
        participant_col: str = "participant_id",
        include: list[dict[str, Any] | Any] | None = None,
Collaborator


@AlexanderFengler include is a legacy naming convention from HDDM. However, to me it's kind of confusing now. Should we deprecate this for something clearer? We can use an alias for now with a deprecation warning and completely remove it in a future release

Member


What would you call it instead here?
We would want to make that change globally not just for this class I guess.

Either way, would do that as a separate PR.

        **kwargs: Any,
    ) -> None:
        # Validate config (ensures ssm_logp_func is present, etc.)
        rlssm_config.validate()
Collaborator


Should we initiate the parent class first?

        # would scramble per-participant trial sequences and corrupt RL dynamics.
        # Raise early so the user gets a clear message before model construction.
        if missing_data is not False:
            raise ValueError(
Collaborator


I think we are implementing it in the future. Maybe NotImplementedError for now?

Comment on lines +187 to +201

        # Build a ModelConfig so HSSMBase._build_model_config can apply the
        # RLSSM-specific fields (response, list_params, choices, bounds, …).
        # default_priors is an empty dict (no parameter-specific priors pre-set)
        # so that the prior_settings="safe" mechanism in HSSMBase assigns
        # sensible priors from bounds. Populating it with params_default scalar
        # floats would fix every parameter as a constant, which is incorrect.
        mc = ModelConfig(
            response=(tuple(rlssm_config.response) if rlssm_config.response else None),
            list_params=list_params,
            choices=(tuple(rlssm_config.choices) if rlssm_config.choices else None),
            default_priors={},
            bounds=rlssm_config.bounds or {},
            extra_fields=extra_fields if extra_fields else None,
            backend="jax",  # RLSSM always uses the JAX backend
        )
Collaborator


I wouldn't do this tbh. The purpose of inheritance is not to funnel sub-class functionalities into base-class functionalities. Rather subclass should expand base-class functionalities through overrides

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.



Comment on lines +49 to +56

    counts = data.groupby(participant_col).size()
    if counts.nunique() != 1:
        raise ValueError(
            "Data must form balanced panels: all participants must have the "
            f"same number of trials. Observed trial counts: {dict(counts)}"
        )

    return int(len(counts)), int(counts.iloc[0])

Copilot AI Mar 4, 2026


validate_balanced_panel() will raise an IndexError on empty input because counts is empty and counts.iloc[0] is accessed. Please add an explicit empty-data check (e.g., if data.empty: raise ValueError(...)) so callers (notably RLSSM.__init__) get a clear, consistent ValueError instead of an internal indexing error.
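The suggested guard could look like this (a sketch of the check to prepend to the helper's body; the function name here is hypothetical):

```python
import pandas as pd


def ensure_nonempty(data: pd.DataFrame) -> None:
    # Hypothetical early check for validate_balanced_panel: an empty frame
    # would otherwise reach counts.iloc[0] and fail with an opaque IndexError.
    if data.empty:
        raise ValueError(
            "Input data is empty; at least one trial per participant is "
            "required to validate a balanced panel."
        )


ensure_nonempty(pd.DataFrame({"participant_id": [0]}))  # non-empty: no-op
```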

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.



Comment on lines 56 to 58

     response: list[str] | None = field(default_factory=DEFAULT_SSM_OBSERVED_DATA.copy)
-    choices: tuple[int, ...] | None = DEFAULT_SSM_CHOICES
+    choices: list[int] | tuple[int, ...] | None = DEFAULT_SSM_CHOICES

Copilot AI Mar 9, 2026


choices was widened to list[int] | tuple[int, ...], but Config.update_choices() is still typed/docs as tuple[int, ...] and _build_model_config can pass a list[int]. To avoid inconsistent public typing (and future mypy confusion), consider updating update_choices (and any related docstrings/types) to accept list[int] | tuple[int, ...] and optionally normalize to a single internal representation (e.g., always store a tuple).

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.



@cpaniaguam cpaniaguam requested a review from digicosmos86 March 9, 2026 20:46
Member

@AlexanderFengler AlexanderFengler left a comment


small comments.


        )
        if deadline is not False:
            raise ValueError(
                "RLSSM does not support `deadline` handling. "
Member


@krishnbera do we actually have a solution for this?

        """
        # Start with defaults
        config = cls.config_class.from_defaults(model, loglik_kind)
        # get_config_class is provided by Config/RLSSMConfig mixin through MRO
Member


why does RLSSMConfig show up here in this file?

Collaborator Author


All this will be cleaned up after #936 and #931 get merged into their respective base branches.

    "decision_process": "ddm",
    "learning_process": {},
    "learning_process_loglik_kind": "blackbox",
    "decision_process_loglik_kind": "analytical",
Member


`learning_process_loglik_kind` is not a valid concept.
