In the `env_step()` method, there's a discrepancy between:
- The action used for the hidden state update (the line where `latent, _ = self.tssm(...)` is called)
- The action actually executed in the environment (which can differ due to random pre-filling)
As a demonstration:
```python
# Generate action from policy
action = self.policy_network(feat).sample().cpu()
# Update hidden state with this action
latent["hidden"] = self.tssm.slice_hidden(latent["hidden"])
hidden = (latent, action)
# But then potentially override the action for random pre-filling
if (self.replay_buffer.num_steps < self.config.pre_fill_steps) and self.config.random_pre_fill_steps:
    action = self.env.sample()  # This is a DIFFERENT action
# Execute the (potentially different) action
obs = self.env.step(action.argmax(dim=-1) if self.config.policy_discrete else action)
```
The newly sampled random action is then stored in the replay buffer.
This creates an inconsistency where:
- The hidden state (`latent`) contains information computed using the policy's action
- The replay buffer stores the actually executed action (which might be random)
- Future TSSM computations will use a hidden state that was computed with a different action than the one actually executed
This breaks the causal consistency of the world model, as the hidden state representation doesn't accurately reflect the true state transition that occurred in the environment.
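For what it's worth, one possible fix would be to resolve the random pre-fill override *before* caching the hidden state, so that `hidden = (latent, action)` always pairs the latent with the action that is actually executed and stored. A minimal, untested sketch, reusing the names from the snippet above:

```python
# Generate action from policy
action = self.policy_network(feat).sample().cpu()

# Resolve the random pre-fill override first, so the cached
# (latent, action) pair matches the action actually executed
if (self.replay_buffer.num_steps < self.config.pre_fill_steps) and self.config.random_pre_fill_steps:
    action = self.env.sample()

# Cache the hidden state together with the action that will be executed
latent["hidden"] = self.tssm.slice_hidden(latent["hidden"])
hidden = (latent, action)

# Execute that same action (it now matches what the replay buffer stores)
obs = self.env.step(action.argmax(dim=-1) if self.config.policy_discrete else action)
```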
Any clarification is appreciated!