
Inject rlssm config directly into hssm base #936

Open

cpaniaguam wants to merge 23 commits into 930-pass-configs-via-dependency-injection-into-model-classes-basemodelconfig-only-dict-supported from inject-RLSSMConfig-directly-into-HSSMBase

Conversation

Collaborator

@cpaniaguam cpaniaguam commented Mar 18, 2026

This pull request makes several important improvements to the RLSSM configuration and model pipeline in the hssm package. The changes clarify the intended use of RLSSM configuration objects, enforce stricter validation (especially for parameter bounds), simplify how parameter defaults and priors are handled, and update the RLSSM model to use its configuration more robustly. Additionally, new tests are added to ensure correct behavior and serialization.

Configuration and Validation Improvements:

  • RLSSMConfig.from_defaults now raises NotImplementedError to prevent accidental misuse, clarifying that RLSSM models must be constructed via from_rlssm_dict or the constructor, not via shared defaults.
  • The validate() method in RLSSMConfig now checks that every parameter in list_params has a corresponding entry in bounds, raising a clear error if not, which helps catch configuration mistakes early.
  • The get_defaults() method in RLSSMConfig is simplified: it always returns (None, bounds) for a parameter, ensuring that priors are assigned from bounds rather than fixed values, and removing unnecessary conversion methods.

RLSSM Model Pipeline Updates:

  • The RLSSM model now stores the RLSSM-specific config (RLSSMConfig) directly as model_config (with the differentiable Op and backend injected), rather than converting it to a generic Config, ensuring all RLSSM-specific fields are preserved and used consistently.
  • The model's _make_model_distribution method now uses the correct list_params and accesses the log-likelihood Op directly from the RLSSM config, improving clarity and type safety.

Testing Enhancements:

  • New tests are added to check for missing bounds in validation, ensure from_defaults raises as intended, verify get_defaults behavior, and confirm RLSSM objects can be correctly serialized and deserialized with cloudpickle.

Other Notable Changes:

  • The RLSSM config is now always copied and updated immutably when the log-likelihood Op and backend are set, preventing accidental mutation of user-supplied config objects.
  • Minor code cleanups and improved docstrings clarify the distinction between parameter defaults (for initialization) and priors (for inference).

Code Consistency and Safety:

  • Ensures list_params is always a list (not a tuple) in the base model, and that all code paths expect and enforce this invariant.
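A tiny sketch of that invariant (variable names are hypothetical; per the file summaries, the real copy happens where HSSMBase reads model_config.list_params):

```python
# Copy the config's parameter list before appending p_outlier, so the
# caller's config is never mutated and the type is guaranteed to be a
# list even if the config held a tuple.
config_params = ("v", "a", "z", "t")       # may arrive as a tuple
list_params = list(config_params)          # copy + normalize to list
list_params.append("p_outlier")

assert isinstance(list_params, list)
assert config_params == ("v", "a", "z", "t")  # original untouched
```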

Overall, these changes make the RLSSM configuration and model pipeline more robust, easier to use, and safer against common misconfigurations.


Configuration and Validation Improvements

  • RLSSMConfig.from_defaults now raises NotImplementedError to prevent misuse and clarify the correct construction path for RLSSM models.
  • RLSSMConfig.validate() now checks that all parameters in list_params have corresponding bounds, raising a clear error if any are missing.
  • get_defaults() in RLSSMConfig now always returns (None, bounds) for a parameter, ensuring priors are assigned from bounds and simplifying the interface.

RLSSM Model Pipeline Updates

  • RLSSM models now store an updated RLSSMConfig (with Op and backend injected) as model_config, instead of converting to a generic Config, preserving all RLSSM-specific fields.
  • _make_model_distribution uses the correct list_params and accesses the log-likelihood Op directly from model_config, improving clarity and correctness.

Testing Enhancements

  • Added tests for missing bounds in validation, from_defaults raising, get_defaults cases, and cloudpickle round-trip serialization for RLSSM models.
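The shape of such a round-trip test can be sketched as below. The actual tests use cloudpickle (which also handles closures and locally defined classes); plain pickle suffices for this plain-data stand-in, and the config dict here is purely illustrative:

```python
import pickle

# Round-trip check: a serialized object should come back equal after
# dumps/loads, as a distinct object.
cfg = {"list_params": ["v", "a"], "bounds": {"v": (-3.0, 3.0), "a": (0.1, 5.0)}}
restored = pickle.loads(pickle.dumps(cfg))

assert restored == cfg
assert restored is not cfg
```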

Code Consistency and Safety

  • Ensures list_params is always a list, not a tuple, throughout the base and RLSSM model code.

Documentation and Clarity

  • Improved docstrings and inline comments to clarify the purpose and usage of RLSSM configuration fields, especially the distinction between parameter defaults and priors.

Contributor

Copilot AI left a comment


Pull request overview

Refactors RLSSM configuration/model initialization so RLSSMConfig is used end-to-end (no conversion to Config), and adjusts default handling so RLSSM parameters are no longer implicitly “fixed” by params_default values.

Changes:

  • Updated RLSSMConfig.get_defaults() to always return None for the default prior, enabling priors to be derived from bounds instead of scalar constants.
  • Removed RLSSMConfig conversion helpers/tests and injected the built RLSSM loglik Op directly onto RLSSMConfig before passing it into HSSMBase.
  • Made HSSMBase copy model_config.list_params into self.list_params to avoid mutating the config’s list.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File descriptions:
  • tests/test_rlssm_config.py: Updates get_defaults expectations; removes conversion-related tests.
  • src/hssm/rl/rlssm.py: Passes RLSSMConfig directly to HSSMBase; injects the built loglik Op and uses self.list_params for distribution building.
  • src/hssm/config.py: Changes RLSSMConfig defaults behavior and deletes conversion methods.
  • src/hssm/base.py: Copies list_params from config to instance to prevent config mutation.


Contributor

Copilot AI left a comment


Pull request overview

This PR refactors RLSSM configuration/model initialization so RLSSMConfig is used directly (no conversion to Config), and adjusts prior/bounds handling so “safe” priors are derived from bounds rather than fixed scalar defaults.

Changes:

  • Remove RLSSMConfig→Config conversion helpers and wire RLSSM to inject the built differentiable loglik Op directly into RLSSMConfig.
  • Update RLSSMConfig.get_defaults() to always return None for the prior value, and add early validation to require bounds for every parameter in list_params.
  • Make HSSMBase copy model_config.list_params into self.list_params to avoid mutation bleed-through; update tests accordingly.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File descriptions:
  • src/hssm/rl/rlssm.py: Injects loglik Op into RLSSMConfig via dataclasses.replace, uses self.list_params for distribution building, updates typing/docs.
  • src/hssm/config.py: Enforces bounds coverage in RLSSMConfig.validate() and changes get_defaults() to return (None, bounds) consistently.
  • src/hssm/base.py: Copies list_params from config into self.list_params to avoid mutating caller/config state.
  • tests/test_rlssm_config.py: Updates fixtures/expectations, adds validation test for missing bounds, removes conversion tests.
Comments suppressed due to low confidence (2)

src/hssm/rl/rlssm.py:234

  • params_is_trialwise is computed from list_params including p_outlier, but make_distribution() strips p_outlier from dist_params before applying params_is_trialwise indexing. This currently leaves an unused trailing boolean and makes the intended alignment unclear. Consider building params_is_trialwise from the params actually passed to loglik (i.e., excluding p_outlier when present) to keep lengths consistent and avoid future misalignment bugs if more scalar params are added.
        # p_outlier is a scalar mixture weight (not trialwise); every other
        # RLSSM parameter is trialwise (the Op receives one value per trial).
        params_is_trialwise = [name != "p_outlier" for name in list_params]
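The suggested realignment can be sketched as follows (parameter names are hypothetical; this mirrors the reviewer's suggestion, not the current hssm code):

```python
# Strip scalar mixture params (p_outlier) first, mirroring what
# make_distribution() does, then build the trialwise mask over exactly
# the params that reach the loglik Op, so lengths always agree.
SCALAR_PARAMS = {"p_outlier"}
list_params = ["v", "a", "z", "t", "rl_alpha", "p_outlier"]

dist_params = [p for p in list_params if p not in SCALAR_PARAMS]
params_is_trialwise = [p not in SCALAR_PARAMS for p in dist_params]

assert len(params_is_trialwise) == len(dist_params) == 5
```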

src/hssm/config.py:479

  • RLSSMConfig.get_defaults() is typed as returning tuple[float | None, ...], but the first element is conceptually a prior spec (and in this implementation is always None). To stay compatible with Config.get_defaults() and the Params machinery (which treats this as a prior spec), it would be clearer/safer to type this as tuple[ParamSpec | None, tuple[float, float] | None] (or similar) rather than float | None.
    def get_defaults(
        self, param: str
    ) -> tuple[float | None, tuple[float, float] | None]:
        """Return default value and bounds for a parameter.


@cpaniaguam cpaniaguam requested a review from Copilot March 18, 2026 19:01
Contributor

Copilot AI left a comment


Pull request overview

This PR simplifies RLSSM configuration flow by removing legacy conversions to Config, injecting the differentiable loglik Op directly into RLSSMConfig, and tightening RLSSMConfig validation (especially around bounds) while updating tests to reflect the new semantics.

Changes:

  • Remove RLSSMConfig→Config conversion methods and update RLSSM to pass an RLSSMConfig (with injected loglik) directly into HSSMBase.
  • Add early validation that every list_params entry has a corresponding bounds entry; adjust get_defaults() to return None for the default prior so priors are derived from bounds.
  • Ensure HSSMBase works on its own copy of list_params (including lapse-driven p_outlier) and update tests, including a new cloudpickle round-trip test.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.

File descriptions:
  • src/hssm/config.py: Tightens RLSSMConfig validation, changes defaults behavior, and removes legacy conversion APIs.
  • src/hssm/rl/rlssm.py: Uses dataclasses.replace to inject loglik/backend into RLSSMConfig and aligns likelihood construction with HSSMBase's self.list_params.
  • src/hssm/base.py: Copies model_config.list_params into self.list_params to avoid mutating caller config and to safely append p_outlier.
  • tests/test_rlssm_config.py: Updates fixtures and adds coverage for missing-bounds validation and from_defaults() raising.
  • tests/test_rlssm.py: Updates lapse-related param alignment test and adds pickle round-trip coverage.
Comments suppressed due to low confidence (1)

src/hssm/config.py:498

  • RLSSMConfig.get_defaults() now always returns None for the prior, but its return annotation is still tuple[float | None, ...]. This is both misleading (the prior type is not a float anymore) and inconsistent with Config.get_defaults() which returns a ParamSpec | None. Consider updating the signature to return tuple[ParamSpec | None, tuple[float, float] | None] (or at least Any | None for the first element) to reflect actual usage in Params.make_param_from_defaults.
    def get_defaults(
        self, param: str
    ) -> tuple[float | None, tuple[float, float] | None]:
        """Return default value and bounds for a parameter.


…asses-basemodelconfig-only-dict-supported' into inject-RLSSMConfig-directly-into-HSSMBase
@cpaniaguam cpaniaguam requested a review from digicosmos86 March 19, 2026 13:20
@cpaniaguam cpaniaguam marked this pull request as ready for review March 19, 2026 18:03
Collaborator

@digicosmos86 digicosmos86 left a comment


Definitely much improved over the original. Still have a few questions though

@cpaniaguam cpaniaguam requested a review from digicosmos86 March 25, 2026 19:52
Member

@AlexanderFengler AlexanderFengler left a comment


Pretty much good to go, left some small comments.

"learning_process",
"response",
"decision_process_loglik_kind",
"learning_process_loglik_kind",
Member


as per our discussions we shouldn't call this loglik_kind here for the learning process.

What we are computing at that level is not a log likelihood; it's just a trajectory over time (trials) for one (or more) parameters of the model.

Collaborator Author


@AlexanderFengler Looking at my notes here: is this the one parameter you said could be specified from other parameters? Also what should the correct name be? Currently in the tests, the value used is always "blackbox". @krishnbera

Member


Logically it shouldn't really be blackbox, because the learning rules we specify so far are all analytical / differentiable.

For a name we can just call it learning_process_kind, it's just not a log likelihood, which was the misleading part.

Then if either the learning-process or the decision-process component is blackbox --> the entire function will be blackbox.

If both are differentiable, the entire thing should be differentiable (but here we need to be careful about how the mixing and matching works in case we have jax and pytensor components).

In principle we should be able to mix analytical and approx_differentiable no problem, from a backend perspective it doesn't really matter if a function is a network or a collection of operations to represent an analytical computation.

For original construction, we need to know if it's one or the other (network/explicitly defined analytical function), because we need to call the right constructor, but ultimately it's differentiable.
