import numpy as np
import torch as th
from torch.nn import functional as F

from stable_baselines3 import DQN


class DoubleDQN(DQN):
    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample the replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # Do not backpropagate gradients to the target network
            with th.no_grad():
                # Compute the next Q-values using the target network
                next_q_values = self.q_net_target(replay_data.next_observations)
                # Decouple action selection from value estimation:
                # compute Q-values for the next observation using the online q-net
                next_q_values_online = self.q_net(replay_data.next_observations)
                # Select the greedy action with the online network
                next_actions_online = th.argmax(next_q_values_online, dim=1)
                # Estimate the Q-values of the selected actions with the target network
                next_q_values = th.gather(next_q_values, dim=1, index=next_actions_online.unsqueeze(-1))
                # 1-step TD target
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-value estimates
            current_q_values = self.q_net(replay_data.observations)
            # Retrieve the Q-values for the actions taken in the replay buffer
            current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())

            # Check the shape
            assert current_q_values.shape == target_q_values.shape

            # Compute Huber loss
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            losses.append(loss.item())

            # Optimize the Q-network
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase the update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))