Conversation
mighty/mighty_agents/base_agent.py
Outdated
# 3) optionally overwrite next_s on truncation
if self.handle_timeout_termination:
Not sure I like the naming. What does it mean to "handle_timeout_termination"? It should be more expressive. Also: when do we want this? Always? On specific envs? Specific algos? I would actually assume always, since we only want next_s for next-action prediction and always want the final obs in the replay. In that case we don't need a flag at all.
Removed the optional flag
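For reference, a minimal sketch of the always-on behavior suggested above, assuming a Gymnasium-style vectorized env that auto-resets and reports the pre-reset observation under `infos["final_observation"]`; `buffer.add` is a hypothetical replay-buffer interface:

```python
import numpy as np

def store_transition(buffer, env, curr_s, action):
    """Sketch: always keep the true final obs in the replay, no flag needed.

    Assumes a Gymnasium-style vectorized env that auto-resets and reports the
    pre-reset observation under infos["final_observation"]; buffer.add is
    a hypothetical interface.
    """
    next_s, reward, terminated, truncated, infos = env.step(action)

    # The replay buffer always receives the true final observation ...
    replay_next_s = np.array(next_s, copy=True)
    if "final_observation" in infos:
        for i, trunc in enumerate(truncated):
            if trunc and infos["final_observation"][i] is not None:
                replay_next_s[i] = infos["final_observation"][i]

    buffer.add(curr_s, action, reward, replay_next_s, terminated)
    # ... while the post-reset next_s is returned for next-action prediction.
    return next_s
```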
next_s, reward, terminated, truncated, infos = self.env.step(action)

# 2) decide which samples are true "done"
replay_dones = terminated  # physics-failure only
Comment here is env-specific. Also inconsistent: dones are always overwritten to real termination regardless of what the flag says.
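If the flag were kept, a consistent version would respect it instead of always overwriting the dones; a minimal sketch:

```python
import numpy as np

def compute_replay_dones(terminated, truncated, handle_timeout_termination):
    """Sketch: respect the flag instead of always overwriting dones."""
    terminated = np.asarray(terminated)
    truncated = np.asarray(truncated)
    if handle_timeout_termination:
        # Bootstrap through time-limit truncations: only real terminations end returns.
        return terminated
    # Otherwise truncation counts as a true episode end.
    return np.logical_or(terminated, truncated)
```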
mighty/mighty_agents/sac.py
Outdated
  # Pack transition
+ # `terminated` is used for physics failures in environments like `MightyEnv`
  transition = TransitionBatch(curr_s, action, reward, next_s, dones)
At least remove the weird AI comments
  elif isinstance(out, tuple) and len(out) == 4:
      action = out[0]  # [batch, action_dim]
+     print(f'Self Model : {self.model}')
      log_prob = sample_nondeterministic_logprobs(
-         z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac"
+         z=out[1], mean=out[2], log_std=out[3], sac=isinstance(self.model, SACModel)
Bad idea! What if I want to implement a different model class for SAC that e.g. handles prediction differently? Then the policy stops functioning.
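The model-agnostic alternative argued for here keys the correction on the configured algorithm rather than on a concrete model class; a sketch based on the line the diff replaced:

```python
# Sketch: switch on the configured algorithm, not on a concrete model class,
# so a custom SAC model class keeps working.
log_prob = sample_nondeterministic_logprobs(
    z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac"
)
```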
z: torch.Tensor,
mean: torch.Tensor,
log_std: torch.Tensor,
sac: bool = False
The flag is here to stay model agnostic. Now you make it impossible to add new model classes for SAC...
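For context, a self-contained sketch of what such a flag typically gates, namely the tanh change-of-variables correction SAC needs (illustrative only, not the repo's actual implementation):

```python
import torch

def sample_nondeterministic_logprobs(
    z: torch.Tensor,
    mean: torch.Tensor,
    log_std: torch.Tensor,
    sac: bool = False,
) -> torch.Tensor:
    """Gaussian log-prob of z; the sac flag adds the tanh squashing correction."""
    normal = torch.distributions.Normal(mean, log_std.exp())
    log_prob = normal.log_prob(z).sum(dim=-1)
    if sac:
        # Change of variables for a = tanh(z): subtract log|da/dz| = log(1 - tanh(z)^2).
        log_prob = log_prob - torch.log(1.0 - torch.tanh(z).pow(2) + 1e-6).sum(dim=-1)
    return log_prob
```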
# 4-tuple case (Tanh squashing): (action, z, mean, log_std)
elif isinstance(model_output, tuple) and len(model_output) == 4:
    action, z, mean, log_std = model_output
    log_prob = sample_nondeterministic_logprobs(
I don't understand the reason for changing this: it's the same code, just longer and locked into a specific model class?
  return action.detach().cpu().numpy(), log_prob
  else:
-     weighted_log_prob = log_prob * self.entropy_coefficient
+     weighted_log_prob = log_prob
This is strange, now both do the same?!
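Presumably the two branches were meant to differ; a hypothetical sketch of what that could look like (the branch condition and the intent are assumptions, not the repo's code):

```python
# Hypothetical: scale by the entropy coefficient only on the training path.
if deterministic:
    weighted_log_prob = log_prob
else:
    weighted_log_prob = log_prob * self.entropy_coefficient
```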
  log_prob = sample_nondeterministic_logprobs(
      z=z, mean=mean, log_std=log_std, sac=self.algo == "sac"
  )
+ if not isinstance(self.model, SACModel):
Same issue as above: an identical function, just longer and worse.
| """ | ||
| feats = self.feature_extractor(state) | ||
| x = self.policy_net(feats) | ||
| x = self.policy_net(state) |
Not in the mighty format. The separate feature extractor is there to have a predictable structure and access to a feature embedding. The "Mighty-er" format would be to have a feature extractor -> policy head, and then a q_feature_extractor. No functional difference, but it's relevant for continuity between algos.
Updated -- performance similar
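A sketch of the structure described above, with hypothetical module names and sizes:

```python
import torch
import torch.nn as nn

class SACPolicySketch(nn.Module):
    """Sketch of the 'mighty format': feature extractor -> policy head, plus a
    separate q_feature_extractor for the critics (names and sizes hypothetical)."""

    def __init__(self, obs_dim: int, action_dim: int, hidden: int = 256):
        super().__init__()
        # Predictable structure: the feature embedding is always accessible.
        self.feature_extractor = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.policy_net = nn.Linear(hidden, 2 * action_dim)  # mean and log_std
        self.q_feature_extractor = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden), nn.ReLU()
        )

    def forward(self, state: torch.Tensor):
        feats = self.feature_extractor(state)  # embedding stays inspectable
        x = self.policy_net(feats)             # policy head consumes features
        mean, log_std = x.chunk(2, dim=-1)
        return mean, log_std
```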
Updates to SAC
- `handle_timeout_termination`: when this is set to true, we treat the final states of `terminated` differently from `truncated`. This is related to [Bug] Infinite horizon tasks are handled like episodic tasks DLR-RM/stable-baselines3#284
- We now use `policy_log_prob()` from the SAC model exclusively for the tanh correction instead of `sample_nondeterministic_logprobs`. The latter can potentially be made PPO-only
- Added `make_policy_head()` to separate the policy head functionality
- The SAC network forward method now handles action rescaling and log_prob resampling
- SAC update uses fresh samples for the alpha update, and exponentiates log_alpha (see the sketch below)
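A minimal sketch of that last point, assuming the standard SAC temperature objective (variable names hypothetical):

```python
import torch

# Keep alpha positive by optimizing log_alpha and exponentiating it.
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

def update_alpha(fresh_log_prob: torch.Tensor, target_entropy: float) -> torch.Tensor:
    """fresh_log_prob: log pi(a|s) for actions re-sampled from the current policy."""
    alpha_loss = -(log_alpha * (fresh_log_prob + target_entropy).detach()).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp()  # alpha actually used in the actor/critic losses
```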