-
Notifications
You must be signed in to change notification settings - Fork 55
KL Adaptive LR for PPO and LR schedule for SAC/TQC #72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
4d9c49c
Only check for terminated episodes
araffin c98ebde
Start adding ortho init
araffin cd2ef11
Add SimbaPolicy for PPO
araffin 0461871
Try adding ortho init to SAC
araffin 3e42262
Enable lr schedule for PPO
araffin 101c08d
Allow to pass lr, prepare for adaptive lr
araffin ef2321d
Implement adaptive lr
araffin 96bf978
Add small test
araffin 6e17def
Refactor adaptive lr
araffin 5832702
Add adaptive lr for SAC
araffin ab33983
Fix qf_learning_rate
araffin 163acbe
Revert "Fix qf_learning_rate"
araffin 0d83720
Revert "Add adaptive lr for SAC"
araffin 85d4f23
Revert kl div for SAC changes
araffin dc6bf4e
Revert dist.mode() in two lines
araffin f0235e6
Cleanup code
araffin 0e14aad
Add support for Gaussian actor for SAC
araffin 1a8063a
Enable Gaussian actor for TQC
araffin cefbd78
Log std too
araffin a809a01
Avoid NaN in kl div approx
araffin c92d840
Allow to use layer_norm in actor
araffin f54697a
Reformat
araffin 7afa84f
Allow max grad norm for TQC and fix optimizer class
araffin 3af7c93
Comment out max grad norm
araffin 3f9727b
Update to schedule classes
araffin d86e89d
Add lr schedule support for TQC
araffin c668fd1
Revert experimental changes and add support for lr schedule for SAC
araffin 22b3e54
Add test for adaptive kl div, remove squash output param
araffin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| from dataclasses import dataclass | ||
|
|
||
| import numpy as np | ||
|
|
||
|
|
||
| @dataclass | ||
| class KLAdaptiveLR: | ||
| """Adaptive lr schedule, see https://arxiv.org/abs/1707.02286""" | ||
|
|
||
| # If set will trigger adaptive lr | ||
| target_kl: float | ||
| current_adaptive_lr: float | ||
| # Values taken from https://github.com/leggedrobotics/rsl_rl | ||
| min_learning_rate: float = 1e-5 | ||
| max_learning_rate: float = 1e-2 | ||
| kl_margin: float = 2.0 | ||
| # Divide or multiply the lr by this factor | ||
| adaptive_lr_factor: float = 1.5 | ||
|
|
||
| def update(self, kl_div: float) -> None: | ||
| if kl_div > self.target_kl * self.kl_margin: | ||
| self.current_adaptive_lr /= self.adaptive_lr_factor | ||
| elif kl_div < self.target_kl / self.kl_margin: | ||
| self.current_adaptive_lr *= self.adaptive_lr_factor | ||
|
|
||
| self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -44,6 +44,8 @@ class Actor(nn.Module): | |||
| # For MultiDiscrete | ||||
| max_num_choices: int = 0 | ||||
| split_indices: np.ndarray = field(default_factory=lambda: np.array([])) | ||||
| # Last layer with small scale | ||||
|
||||
| # Last layer with small scale |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The learning rate update logic is duplicated in both on-policy and off-policy algorithm classes; consider refactoring this into a shared utility to reduce code duplication.