Inject RLSSM config directly into HSSM base #936
Conversation
…e unused conversion tests
…essary Config conversion
… instead of fixed values
Pull request overview
Refactors RLSSM configuration/model initialization so RLSSMConfig is used end-to-end (no conversion to Config), and adjusts default handling so RLSSM parameters are no longer implicitly “fixed” by params_default values.
Changes:
- Updated `RLSSMConfig.get_defaults()` to always return `None` for the default prior, enabling priors to be derived from bounds instead of scalar constants.
- Removed RLSSMConfig conversion helpers/tests and injected the built RLSSM `loglik` Op directly onto `RLSSMConfig` before passing it into `HSSMBase`.
- Made `HSSMBase` copy `model_config.list_params` into `self.list_params` to avoid mutating the config's list.
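The copy-on-init pattern described above can be sketched as follows. This is a minimal illustration, not the actual `HSSMBase` implementation; the `FakeConfig` class and `has_lapse` flag are hypothetical stand-ins.

```python
# Sketch of copying list_params at init so later mutation (e.g. appending
# "p_outlier" for lapse handling) never bleeds back into the caller's config.
from dataclasses import dataclass, field


@dataclass
class FakeConfig:
    # Stand-in for a model config; field names are illustrative.
    list_params: list[str] = field(default_factory=lambda: ["v", "a", "rl_alpha"])


class Base:
    def __init__(self, model_config: FakeConfig, has_lapse: bool = True):
        # Copy so appending "p_outlier" mutates only the instance's list.
        self.list_params = list(model_config.list_params)
        if has_lapse and "p_outlier" not in self.list_params:
            self.list_params.append("p_outlier")


cfg = FakeConfig()
m = Base(cfg)
assert "p_outlier" in m.list_params
assert "p_outlier" not in cfg.list_params  # the config stays untouched
```

Without the `list(...)` copy, the appended `p_outlier` would silently persist on the shared config object across model constructions.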
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `tests/test_rlssm_config.py` | Updates `get_defaults` expectations; removes conversion-related tests. |
| `src/hssm/rl/rlssm.py` | Passes `RLSSMConfig` directly to `HSSMBase`; injects built `loglik` Op and uses `self.list_params` for distribution building. |
| `src/hssm/config.py` | Changes `RLSSMConfig` defaults behavior and deletes conversion methods. |
| `src/hssm/base.py` | Copies `list_params` from config to instance to prevent config mutation. |
…Config instance, preserving the original configuration.
Pull request overview
This PR refactors RLSSM configuration/model initialization so RLSSMConfig is used directly (no conversion to Config), and adjusts prior/bounds handling so “safe” priors are derived from bounds rather than fixed scalar defaults.
Changes:
- Remove RLSSMConfig→Config conversion helpers and wire RLSSM to inject the built differentiable `loglik` Op directly into `RLSSMConfig`.
- Update `RLSSMConfig.get_defaults()` to always return `None` for the prior value, and add early validation to require bounds for every parameter in `list_params`.
- Make `HSSMBase` copy `model_config.list_params` into `self.list_params` to avoid mutation bleed-through; update tests accordingly.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `src/hssm/rl/rlssm.py` | Injects `loglik` Op into `RLSSMConfig` via `dataclasses.replace`, uses `self.list_params` for distribution building, updates typing/docs. |
| `src/hssm/config.py` | Enforces bounds coverage in `RLSSMConfig.validate()` and changes `get_defaults()` to return `(None, bounds)` consistently. |
| `src/hssm/base.py` | Copies `list_params` from config into `self.list_params` to avoid mutating caller/config state. |
| `tests/test_rlssm_config.py` | Updates fixtures/expectations, adds validation test for missing bounds, removes conversion tests. |
Comments suppressed due to low confidence (2)
src/hssm/rl/rlssm.py:234
`params_is_trialwise` is computed from `list_params` including `p_outlier`, but `make_distribution()` strips `p_outlier` from `dist_params` before applying `params_is_trialwise` indexing. This currently leaves an unused trailing boolean and makes the intended alignment unclear. Consider building `params_is_trialwise` from the params actually passed to `loglik` (i.e., excluding `p_outlier` when present) to keep lengths consistent and avoid future misalignment bugs if more scalar params are added.
```python
# p_outlier is a scalar mixture weight (not trialwise); every other
# RLSSM parameter is trialwise (the Op receives one value per trial).
params_is_trialwise = [name != "p_outlier" for name in list_params]
```
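The length mismatch the reviewer describes, and the suggested fix, can be sketched as follows. Parameter names here are illustrative:

```python
# The mask is built from the full list_params (length 5), but p_outlier is
# stripped from dist_params before the mask is applied (length 4), leaving a
# dangling trailing boolean.
list_params = ["v", "a", "z", "rl_alpha", "p_outlier"]

old_mask = [name != "p_outlier" for name in list_params]       # length 5
dist_params = [p for p in list_params if p != "p_outlier"]     # length 4

assert len(old_mask) != len(dist_params)  # the misalignment

# Suggested fix: derive the mask from the params actually handed to loglik,
# so the two sequences always stay the same length.
new_mask = [True for _ in dist_params]

assert len(new_mask) == len(dist_params)
```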
src/hssm/config.py:479
`RLSSMConfig.get_defaults()` is typed as returning `tuple[float | None, ...]`, but the first element is conceptually a prior spec (and in this implementation is always `None`). To stay compatible with `Config.get_defaults()` and the `Params` machinery (which treats this as a prior spec), it would be clearer/safer to type this as `tuple[ParamSpec | None, tuple[float, float] | None]` (or similar) rather than `float | None`.
```python
def get_defaults(
    self, param: str
) -> tuple[float | None, tuple[float, float] | None]:
    """Return default value and bounds for a parameter.
```
… for default parameter bounds
… add pickle round-trip verification
…ror for unsupported usage
Pull request overview
This PR simplifies RLSSM configuration flow by removing legacy conversions to Config, injecting the differentiable loglik Op directly into RLSSMConfig, and tightening RLSSMConfig validation (especially around bounds) while updating tests to reflect the new semantics.
Changes:
- Remove `RLSSMConfig`→`Config` conversion methods and update RLSSM to pass an `RLSSMConfig` (with injected `loglik`) directly into `HSSMBase`.
- Add early validation that every `list_params` entry has a corresponding `bounds` entry; adjust `get_defaults()` to return `None` for the default prior so priors are derived from bounds.
- Ensure `HSSMBase` works on its own copy of `list_params` (including lapse-driven `p_outlier`) and update tests, including a new cloudpickle round-trip test.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| `src/hssm/config.py` | Tightens `RLSSMConfig` validation, changes defaults behavior, and removes legacy conversion APIs. |
| `src/hssm/rl/rlssm.py` | Uses `dataclasses.replace` to inject `loglik`/backend into `RLSSMConfig` and aligns likelihood construction with `HSSMBase`'s `self.list_params`. |
| `src/hssm/base.py` | Copies `model_config.list_params` into `self.list_params` to avoid mutating caller config and to safely append `p_outlier`. |
| `tests/test_rlssm_config.py` | Updates fixtures and adds coverage for missing-bounds validation and `from_defaults()` raising. |
| `tests/test_rlssm.py` | Updates lapse-related param alignment test and adds pickle round-trip coverage. |
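A round-trip serialization check like the one the table describes follows this assertion pattern. The real test uses `cloudpickle` on a full RLSSM model; here stdlib `pickle` and a stand-in object illustrate the same idea:

```python
# Sketch of a pickle round-trip test: serialize, deserialize, and assert the
# restored object carries the same state as the original.
import pickle


class ModelStandIn:
    """Hypothetical stand-in for an RLSSM model object."""

    def __init__(self, list_params: list[str]):
        self.list_params = list(list_params)


model = ModelStandIn(["v", "a", "rl_alpha", "p_outlier"])
restored = pickle.loads(pickle.dumps(model))

assert restored.list_params == model.list_params  # state survives
assert restored is not model                      # a genuine copy
```

`cloudpickle` matters for the real model because it can serialize dynamically constructed objects (like compiled likelihood Ops and closures) that stdlib `pickle` often cannot.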
Comments suppressed due to low confidence (1)
src/hssm/config.py:498
`RLSSMConfig.get_defaults()` now always returns `None` for the prior, but its return annotation is still `tuple[float | None, ...]`. This is both misleading (the prior type is not a float anymore) and inconsistent with `Config.get_defaults()`, which returns a `ParamSpec | None`. Consider updating the signature to return `tuple[ParamSpec | None, tuple[float, float] | None]` (or at least `Any | None` for the first element) to reflect actual usage in `Params.make_param_from_defaults`.
```python
def get_defaults(
    self, param: str
) -> tuple[float | None, tuple[float, float] | None]:
    """Return default value and bounds for a parameter.
```
…asses-basemodelconfig-only-dict-supported' into inject-RLSSMConfig-directly-into-HSSMBase
digicosmos86 left a comment:
Definitely much improved over the original. Still have a few questions though
AlexanderFengler left a comment:
Pretty much good to go, left some small comments.
| "learning_process", | ||
| "response", | ||
| "decision_process_loglik_kind", | ||
| "learning_process_loglik_kind", |
As per our discussions, we shouldn't call this `loglik_kind` here for the learning process. It's not a log likelihood that we are computing at that level; it's just a trajectory over time (trials) for one (or more) parameters of the model.
@AlexanderFengler Looking at my notes here: is this the one parameter you said could be specified from other parameters? Also what should the correct name be? Currently in the tests, the value used is always "blackbox". @krishnbera
Logically it shouldn't really be `blackbox`, because the learning rules we specify so far are all analytical/differentiable.
For a name we can just call it `learning_process_kind`; it's just not a log likelihood, which was the misleading part.
Then if either the learning process or the decision process component is blackbox, the entire function will be blackbox.
If they are both differentiable, the entire thing should be differentiable (but here we need to be careful about how the mix and match works in case we have JAX and PyTensor components).
In principle we should be able to mix analytical and approx_differentiable with no problem; from a backend perspective it doesn't really matter whether a function is a network or a collection of operations representing an analytical computation.
For original construction, we need to know which it is (network vs. explicitly defined analytical function), because we need to call the right constructor, but ultimately it's differentiable.
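The combination rule described in this comment can be sketched as a small function. This is illustrative only, not HSSM API; the kind names follow the discussion above:

```python
# Sketch: the combined likelihood is blackbox if either component is
# blackbox; otherwise it remains differentiable, since analytical and
# approx_differentiable components mix freely from a backend perspective.
def combined_kind(learning_process_kind: str, decision_process_kind: str) -> str:
    kinds = {learning_process_kind, decision_process_kind}
    if "blackbox" in kinds:
        return "blackbox"
    # Both analytical and approx_differentiable are differentiable; the
    # distinction only matters for picking the right constructor.
    return "differentiable"


assert combined_kind("analytical", "approx_differentiable") == "differentiable"
assert combined_kind("blackbox", "analytical") == "blackbox"
```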
…ed data constants
This pull request makes several important improvements to the RLSSM configuration and model pipeline in the `hssm` package. The changes clarify the intended use of RLSSM configuration objects, enforce stricter validation (especially for parameter bounds), simplify how parameter defaults and priors are handled, and update the RLSSM model to use its configuration more robustly. Additionally, new tests are added to ensure correct behavior and serialization.

Configuration and Validation Improvements:

- `RLSSMConfig.from_defaults` now raises `NotImplementedError` to prevent accidental misuse, clarifying that RLSSM models must be constructed via `from_rlssm_dict` or the constructor, not via shared defaults.
- The `validate()` method in `RLSSMConfig` now checks that every parameter in `list_params` has a corresponding entry in `bounds`, raising a clear error if not, which helps catch configuration mistakes early.
- The `get_defaults()` method in `RLSSMConfig` is simplified: it always returns `(None, bounds)` for a parameter, ensuring that priors are assigned from bounds rather than fixed values, and removing unnecessary conversion methods.

RLSSM Model Pipeline Updates:
- The RLSSM model now passes its configuration (`RLSSMConfig`) directly as `model_config` (with the differentiable Op and backend injected), rather than converting it to a generic `Config`, ensuring all RLSSM-specific fields are preserved and used consistently. [1] [2] [3]
- The `_make_model_distribution` method now uses the correct `list_params` and accesses the log-likelihood Op directly from the RLSSM config, improving clarity and type safety. [1] [2] [3]

Testing Enhancements:
- New tests ensure `from_defaults` raises as intended, verify `get_defaults` behavior, and confirm RLSSM objects can be correctly serialized and deserialized with `cloudpickle`. [1] [2]

Other Notable Changes:
Code Consistency and Safety:
- Ensures `list_params` is always a list (not a tuple) in the base model, and that all code paths expect and enforce this invariant. [1] [2]

Overall, these changes make the RLSSM configuration and model pipeline more robust, easier to use, and safer against common misconfigurations.
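The early bounds validation summarized above can be sketched as follows. The function name and error message are illustrative, not the actual `RLSSMConfig.validate()` implementation:

```python
# Sketch: require a bounds entry for every parameter in list_params, failing
# fast with a clear error instead of producing a misconfigured model later.
def validate_bounds(
    list_params: list[str], bounds: dict[str, tuple[float, float]]
) -> None:
    missing = [p for p in list_params if p not in bounds]
    if missing:
        raise ValueError(f"Missing bounds for parameters: {missing}")


validate_bounds(["v", "a"], {"v": (-3.0, 3.0), "a": (0.3, 2.5)})  # passes

try:
    validate_bounds(["v", "rl_alpha"], {"v": (-3.0, 3.0)})
except ValueError as err:
    assert "rl_alpha" in str(err)  # the missing parameter is named
```

Failing at config-construction time means the error points at the config mistake itself, rather than surfacing later as a cryptic failure during prior construction or sampling.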
Configuration and Validation Improvements
- `RLSSMConfig.from_defaults` now raises `NotImplementedError` to prevent misuse and clarify the correct construction path for RLSSM models.
- `RLSSMConfig.validate()` now checks that all parameters in `list_params` have corresponding `bounds`, raising a clear error if any are missing.
- `get_defaults()` in `RLSSMConfig` now always returns `(None, bounds)` for a parameter, ensuring priors are assigned from bounds and simplifying the interface.

RLSSM Model Pipeline Updates
- The model now passes `RLSSMConfig` (with Op and backend injected) as `model_config`, instead of converting to a generic `Config`, preserving all RLSSM-specific fields. [1] [2] [3]
- `_make_model_distribution` uses the correct `list_params` and accesses the log-likelihood Op directly from `model_config`, improving clarity and correctness. [1] [2] [3]

Testing Enhancements
- New tests cover `from_defaults` raising, `get_defaults` cases, and cloudpickle round-trip serialization for RLSSM models. [1] [2]

Code Consistency and Safety
- Ensures `list_params` is always a list, not a tuple, throughout the base and RLSSM model code. [1] [2]

Documentation and Clarity