-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
54 lines (41 loc) · 1.39 KB
/
train.py
File metadata and controls
54 lines (41 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gym
import numpy as np
import utils
import matplotlib.pyplot as plt
def sample(theta, env, N):
    """Roll out N trajectories under the current (linear) policy.

    :param theta: the model parameters (shape d x 1)
    :param env: the environment used to sample from
    :param N: number of trajectories to sample
    :return:
        trajectories_gradients: list with one sublist of gradients per rollout
        trajectories_rewards: list with one sublist of rewards per rollout
    Note: the maximum trajectory length is 200 steps
    """
    trajectories_gradients = []
    trajectories_rewards = []
    # TODO: perform the N rollouts and collect per-step gradients and rewards
    return trajectories_gradients, trajectories_rewards
def train(N, T, delta):
    """Train the policy parameters on CartPole-v0.

    :param N: number of trajectories to sample in each time step
    :param T: number of iterations to train the model
    :param delta: trust region size
    :return:
        theta: the trained model parameters
        avg_episodes_rewards: list of average rewards for each time step
    """
    # Random initial parameters; dimensionality 100 matches the feature map
    # used elsewhere in this project — presumably via utils; verify there.
    theta = np.random.rand(100, 1)
    env = gym.make('CartPole-v0')
    env.seed(12345)
    avg_episode_rewards = []
    # TODO: run T iterations of sampling (via sample) and parameter updates
    return theta, avg_episode_rewards
if __name__ == '__main__':
    # Fix the numpy seed so runs are reproducible.
    np.random.seed(1234)
    trained_theta, rewards_per_step = train(N=100, T=20, delta=1e-2)
    # Plot the learning curve: average reward at each training iteration.
    plt.plot(rewards_per_step)
    plt.title("avg rewards per timestep")
    plt.xlabel("timestep")
    plt.ylabel("avg rewards")
    plt.show()