-
Notifications
You must be signed in to change notification settings - Fork 19
Rlssm class make model dist #915
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
cpaniaguam
wants to merge
32
commits into
cp-main-sb
Choose a base branch
from
rlssm-class-make-model-dist
base: cp-main-sb
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
20ddc4c
Add ssm_logp_func to RLSSMConfig and update validation tests
cpaniaguam d97dcee
Add RLSSM model and utilities for reinforcement learning integration
cpaniaguam a6a0238
Refactor RLSSM parameter handling and add custom prefix resolution fo…
cpaniaguam d880977
Add tests for RLSSM class covering initialization, validation, and mo…
cpaniaguam bef8d6c
Refactor loglik handling in RLSSM to improve type safety with casting
cpaniaguam 3981ef6
Add NaN value check for participant column in validate_balanced_panel…
cpaniaguam d84a800
Add validation for ssm_logp_func in RLSSMConfig to ensure it is calla…
cpaniaguam 15ad6e2
Add exclude rules for ruff and mypy hooks to skip tests directory
cpaniaguam 262ec07
Add validation tests for ssm_logp_func in RLSSMConfig to ensure it is…
cpaniaguam 381275a
Add tests for NaN participant_id and unannotated ssm_logp_func in RLSSM
cpaniaguam 0e9ba42
Reject missing data and deadline handling in RLSSM initialization to …
cpaniaguam 4f28c68
Add tests to validate error handling for missing data and deadline in…
cpaniaguam 5e9f566
Refactor path handling for loading RLDM fixture dataset in tests
cpaniaguam 67ac2ce
Add fixture to set floatX to float32 for module tests
cpaniaguam e1c05df
Ensure params_is_trialwise aligns with list_params in RLSSM initializ…
cpaniaguam 564232b
Clarify comments on default_priors in ModelConfig and remove unnecess…
cpaniaguam bafc037
Update RLSSM to use to_numpy(copy=True) for extra_fields and add test…
cpaniaguam ba358a4
Refactor parameter name resolution in RLSSM to handle underscores cor…
cpaniaguam 0bfa755
Add test for _get_prefix method in RLSSM to ensure token-based matching
cpaniaguam 5b8a16a
Refactor RLSSMConfig.from_rlssm_dict to remove model_name parameter a…
cpaniaguam f69f2b6
Fix comment in test_rlssm.py to clarify output shape of log-likelihoo…
cpaniaguam bad943d
Update RLSSMConfig documentation to mark description as required
cpaniaguam 241aad2
Add ssm_logp_func to RLSSM_REQUIRED_FIELDS and update RLSSMConfig ini…
cpaniaguam ca3816d
Add dummy ssm_logp_func to tests and validate its presence in RLSSMCo…
cpaniaguam 827025c
Remove unused logging import from rlssm.py
cpaniaguam 292d6f0
Remove redundant exclude rule for ruff-format in pre-commit configura…
cpaniaguam 3b1aaf4
Add to_model_config method to RLSSMConfig for ModelConfig conversion
cpaniaguam 0678c45
Refactor RLSSM to delegate ModelConfig construction to RLSSMConfig an…
cpaniaguam 26a9336
Integrate Config and RLSSMConfig into HSSM and RLSSM classes for impr…
cpaniaguam cd660da
Update choices type from list to tuple for consistency in BaseModelCo…
cpaniaguam db308c6
Update choices type from list to tuple in test_constructor for consis…
cpaniaguam ef000cd
Fix formatting of error messages in TestRLSSMConfigValidation for con…
cpaniaguam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -264,8 +264,6 @@ class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin): | |
| The jitter value for the initial values. | ||
| """ | ||
|
|
||
| config_class = Config | ||
|
|
||
| def __init__( | ||
| self, | ||
| data: pd.DataFrame, | ||
|
|
@@ -529,7 +527,8 @@ def _build_model_config( | |
| A complete Config object with choices and other settings applied. | ||
| """ | ||
| # Start with defaults | ||
| config = cls.config_class.from_defaults(model, loglik_kind) | ||
| # get_config_class is provided by Config/RLSSMConfig mixin through MRO | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why does RLSSMConfig show up here in this file?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| config = cls.get_config_class().from_defaults(model, loglik_kind) # type: ignore[attr-defined] | ||
|
|
||
| # Handle user-provided model_config | ||
| if model_config is not None: | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| """Reinforcement learning extensions for HSSM. | ||
|
|
||
| This sub-package provides: | ||
|
|
||
| - :class:`~hssm.rl.rlssm.RLSSM` — the RL + SSM model class. | ||
| - :func:`~hssm.rl.utils.validate_balanced_panel` — panel-balance utility. | ||
| - :mod:`hssm.rl.likelihoods` — log-likelihood builders | ||
| (:func:`~hssm.rl.likelihoods.builder.make_rl_logp_func`, | ||
| :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`). | ||
| """ | ||
|
|
||
| from .rlssm import RLSSM | ||
| from .utils import validate_balanced_panel | ||
|
|
||
| __all__ = [ | ||
| "RLSSM", | ||
| "validate_balanced_panel", | ||
| ] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.