draft_retrace #695
base: pytorch
Conversation
emailweixu left a comment
Take a look at https://github.com/HorizonRobotics/alf/blob/pytorch/docs/contributing.rst to format your code properly.
alf/algorithms/td_loss.py
Outdated
step_types=experience.step_type,
discounts=experience.discount * self._gamma)
else:
elif train_info == None:
Instead of checking whether train_info is None, you should add an argument in __init__ to indicate whether to use retrace.
You should also change SarsaAlgorithm and SacAlgorithm to pass in train_info.
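A minimal sketch of what this suggestion could look like. This is not the actual ALF API; the class body, method names, and the placeholder return values are hypothetical, and only the shape of the `use_retrace` flag and the `train_info` plumbing is the reviewer's point.

```python
# Hypothetical sketch (not the real ALF TDLoss): select retrace with an
# explicit bool flag set in __init__, instead of inferring it from
# train_info being None.
class TDLoss:
    def __init__(self, gamma=0.99, td_lambda=0.95, use_retrace=False):
        """
        Args:
            use_retrace (bool): if True, compute the retrace loss using
                importance ratios derived from ``train_info``.
        """
        self._gamma = gamma
        self._td_lambda = td_lambda
        self._use_retrace = use_retrace

    def forward(self, experience, value, target_value, train_info=None):
        if self._use_retrace:
            # SarsaAlgorithm / SacAlgorithm would have to pass train_info.
            assert train_info is not None, (
                "train_info is required when use_retrace=True")
            return "retrace_loss"  # placeholder for the retrace computation
        return "td_lambda_loss"    # placeholder for the ordinary TD loss
```

Callers that never enable retrace are unaffected, since both new arguments default to the old behavior.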
alf/utils/value_ops.py
Outdated
return advs.detach()
####### add for the retrace method
def generalized_advantage_estimation_retrace(importance_ratio, discounts, rewards, td_lambda, time_major, values, target_value,step_types):
please comment following the way of other functions.
Also need unittest for this function.
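A unit test for this function might take the following shape. The helper `retrace_advantages` below is a simplified pure-Python stand-in, not the actual ALF implementation (which operates on time-major torch tensors); its name, argument layout, and the backward recursion are assumptions based on the standard retrace(λ) formulation with clipped importance ratios.

```python
import unittest


def retrace_advantages(rewards, values, target_values, ratios,
                       gamma=0.99, lam=0.95):
    """Simplified retrace advantage for one trajectory (hypothetical helper).

    Backward recursion: adv_t = delta_t + gamma * lam * c_{t+1} * adv_{t+1},
    where c = min(1, importance_ratio) and
    delta_t = rewards[t+1] + gamma * target_values[t+1] - values[t].
    """
    T = len(values) - 1
    advs = [0.0] * T
    next_adv = 0.0
    for t in reversed(range(T)):
        c = min(1.0, ratios[t + 1])
        delta = rewards[t + 1] + gamma * target_values[t + 1] - values[t]
        next_adv = delta + gamma * lam * c * next_adv
        advs[t] = next_adv
    return advs


class RetraceAdvantageTest(unittest.TestCase):
    def test_ratio_one_reduces_to_gae(self):
        # With all importance ratios >= 1, retrace reduces to plain GAE.
        advs = retrace_advantages(
            rewards=[0.0, 1.0, 1.0], values=[1.0, 2.0, 3.0],
            target_values=[1.0, 2.0, 3.0], ratios=[1.0, 1.0, 1.0],
            gamma=1.0, lam=1.0)
        self.assertAlmostEqual(advs[0], 4.0)
        self.assertAlmostEqual(advs[1], 2.0)

    def test_clipped_ratio_downweights_future_td_errors(self):
        # Halving the ratio at step 1 halves the bootstrapped term in advs[0].
        advs = retrace_advantages(
            rewards=[0.0, 1.0, 1.0], values=[1.0, 2.0, 3.0],
            target_values=[1.0, 2.0, 3.0], ratios=[1.0, 0.5, 0.5],
            gamma=1.0, lam=1.0)
        self.assertAlmostEqual(advs[0], 3.0)
```

The two cases check a known limit (ratio 1 recovers GAE) and a hand-computed clipped-ratio value, which is usually enough to catch indexing and sign mistakes in the recursion.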
- line too long
- add space after ,
- comments for the function need to be added
alf/algorithms/sarsa_algorithm.py
Outdated
info.target_critics)
loss_info = self._critic_losses[i](shifted_experience, critic,
target_critic)
target_critic,info)
add space after ,
alf/algorithms/td_loss.py
Outdated
td_error_loss_fn=element_wise_squared_loss,
td_lambda=0.95,
normalize_target=False,
some-feature-retrace
need to be removed
alf/algorithms/td_loss.py
Outdated
some-feature-retrace
use_retrace=0,
pytorch
need to be removed
alf/algorithms/td_loss.py
Outdated
self._debug_summaries = debug_summaries
self._normalize_target = normalize_target
self._target_normalizer = None
some-feature-retrace
remove; these seem to be leftover markers from a merge
alf/algorithms/td_loss.py
Outdated
def forward(self, experience, value, target_value):
pytorch
remove
else:
scope = alf.summary.scope(self.__class__.__name__)
importance_ratio,importance_ratio_clipped = value_ops.action_importance_ratio(
action_distribution=train_info.action_distribution,
format, line is too long
Not fixed?
Haichao-Zhang left a comment
There seem to be many format issues. You may need to follow the workflow here to set up the formatting tools and also get a reference for the coding standard:
https://alf.readthedocs.io/en/latest/contributing.html#workflow
alf/algorithms/td_loss.py
Outdated
where the generalized advantage estimation is defined as:
:math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`
use_retrace = 0 means one step or multi_step loss, use_retrace = 1 means retrace loss
Can you change use_retrace to use a bool value?
Need to update comment
alf/algorithms/td_loss.py
Outdated
else:
scope = alf.summary.scope(self.__class__.__name__)
importance_ratio,importance_ratio_clipped = value_ops.action_importance_ratio(
add space after ,
alf/utils/value_ops.py
Outdated
return advs.detach()
####### add for the retrace method
def generalized_advantage_estimation_retrace(importance_ratio, discounts, rewards, td_lambda, time_major, values, target_value,step_types):
- line too long
- add space after ,
- comments for the function need to be added
expected=expected)
class GeneralizedAdvantage_retrace_Test(unittest.TestCase):
"""Tests for alf.utils.value_ops
comments not correct
alf/algorithms/td_loss.py
Outdated
where the generalized advantage estimation is defined as:
:math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`
use_retrace = 0 means one step or multi_step loss, use_retrace = 1 means retrace loss
Need to update comment
alf/algorithms/td_loss.py
Outdated
log_prob_clipping=0.0,
scope=scope,
check_numerics=False,
debug_summaries=True)
debug_summaries=debug_summaries
####### add for the retrace method
def generalized_advantage_estimation_retrace(importance_ratio, discounts,
This function can be merged with generalized_advantage_estimation function
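One shape the suggested merge could take: since GAE and retrace share the same backward recursion and differ only in the clipped importance-ratio factor, a single function with an optional `importance_ratio` argument (treated as 1 when absent) can cover both. This is a simplified pure-Python sketch, not ALF's tensor-based implementation; names and the delta definition are assumptions.

```python
def generalized_advantage_estimation(rewards, values, target_values,
                                     gamma, td_lambda,
                                     importance_ratio=None):
    """Sketch of the merged function (hypothetical, single trajectory).

    With importance_ratio=None the clipping factor c is 1 and this is
    plain GAE; with ratios given, c = min(1, ratio) yields retrace.
    """
    T = len(values) - 1
    advs = [0.0] * T
    next_adv = 0.0
    for t in reversed(range(T)):
        c = (1.0 if importance_ratio is None
             else min(1.0, importance_ratio[t + 1]))
        delta = rewards[t + 1] + gamma * target_values[t + 1] - values[t]
        next_adv = delta + gamma * td_lambda * c * next_adv
        advs[t] = next_adv
    return advs
```

Keeping one function avoids duplicating the recursion and lets existing callers pass no ratio and keep their current behavior.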
else:
scope = alf.summary.scope(self.__class__.__name__)
importance_ratio,importance_ratio_clipped = value_ops.action_importance_ratio(
action_distribution=train_info.action_distribution,
Not fixed?
alf/algorithms/td_loss.py
Outdated
target_value (torch.Tensor): the time-major tensor for the value at
each time step. This is used to calculate return. ``target_value``
can be same as ``value``.
train_info (sarsa info, sac info): information used to calcuate importance_ratio
What is sarsa info, sac info here? Can this function be used with other algorithms beyond sac and sarsa?
Changed the code in value_ops and td_loss. The default value for train_info is None. If the train_info parameter is given and lambda is not equal to 1 or 0, we use the retrace method, so we do not need to change the code of sac_algorithm or sarsa_algorithm for people who do not want the retrace method.
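The dispatch rule described in this reply can be reduced to a small predicate. The helper name below is hypothetical; it only illustrates the stated condition (train_info provided and lambda strictly between 0 and 1).

```python
def should_use_retrace(train_info, td_lambda):
    """Sketch of the described dispatch rule (hypothetical helper):
    use retrace only when train_info is provided and td_lambda is
    strictly between 0 and 1; otherwise fall back to the existing
    one-step / multi-step TD loss path."""
    return train_info is not None and 0.0 < td_lambda < 1.0
```

Note that this makes the retrace/no-retrace choice implicit, which is exactly what the reviewer argued against earlier in the thread; an explicit bool flag would make the behavior visible at the call site.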