diff --git a/1.dqn.ipynb b/1.dqn.ipynb index b3f33f4..400dc1f 100644 --- a/1.dqn.ipynb +++ b/1.dqn.ipynb @@ -172,11 +172,11 @@ " super(DQN, self).__init__()\n", " \n", " self.layers = nn.Sequential(\n", - " nn.Linear(env.observation_space.shape[0], 128),\n", + " nn.Linear(num_inputs, 128),\n", " nn.ReLU(),\n", " nn.Linear(128, 128),\n", " nn.ReLU(),\n", - " nn.Linear(128, env.action_space.n)\n", + " nn.Linear(128, num_actions)\n", " )\n", " \n", " def forward(self, x):\n", @@ -186,7 +186,7 @@ " if random.random() > epsilon:\n", " state = Variable(torch.FloatTensor(state).unsqueeze(0), volatile=True)\n", " q_value = self.forward(state)\n", - " action = q_value.max(1)[1].data[0]\n", + " action = q_value.max(1)[1].item()\n", " else:\n", " action = random.randrange(env.action_space.n)\n", " return action"