-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathARS.py
More file actions
43 lines (33 loc) · 1.16 KB
/
ARS.py
File metadata and controls
43 lines (33 loc) · 1.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Augmented Random Search (ARS) AI
# Importing the libraries
import os
import numpy as np
# Setting the hyperparameters
class Hp():
    """Container for the hyperparameters of an Augmented Random Search run.

    Every value can be overridden through a keyword argument; calling
    ``Hp()`` with no arguments yields the same configuration as before.

    The code that consumes these values is not visible in this file, so the
    descriptions below are inferred from the attribute names -- confirm
    against the training loop.

    Args:
        nb_steps: presumably the number of training iterations.
        episode_length: presumably the maximum number of steps per episode.
        learning_rate: presumably the step size of the policy update.
        nb_directions: presumably the number of perturbation directions
            sampled per step.
        nb_best_directions: presumably how many of the sampled directions
            are kept for the update; must not exceed ``nb_directions``.
        noise: presumably the scale of the sampled perturbations.
        seed: presumably the random seed.
        env_name: presumably the identifier of the environment to create.

    Raises:
        ValueError: if ``nb_best_directions`` is greater than
            ``nb_directions``.
    """

    def __init__(self, nb_steps=1000, episode_length=1000,
                 learning_rate=0.02, nb_directions=16,
                 nb_best_directions=16, noise=0.03, seed=1, env_name=''):
        # Raise explicitly instead of using ``assert``: assertions are
        # stripped when Python runs with -O, which would silently disable
        # this consistency check.
        if nb_best_directions > nb_directions:
            raise ValueError(
                'nb_best_directions ({}) must be <= nb_directions ({})'.format(
                    nb_best_directions, nb_directions))
        self.nb_steps = nb_steps
        self.episode_length = episode_length
        self.learning_rate = learning_rate
        self.nb_directions = nb_directions
        self.nb_best_directions = nb_best_directions
        self.noise = noise
        self.seed = seed
        self.env_name = env_name
# Normalizing the states
class Normalizer():
    """Running per-component mean/variance tracker used to standardize inputs.

    ``observe`` folds one sample into the running statistics with a
    Welford-style online update; ``normalize`` standardizes a value with the
    statistics accumulated so far.

    Attributes:
        n: per-component count of observed samples.
        mean: running mean.
        mean_diff: running sum of squared deviations from the mean.
        var: ``mean_diff / n``, floored at ``_MIN_VAR``; stays at zero until
            the first call to ``observe``.
    """

    # Floor applied to the variance so that the standard deviation used as a
    # divisor can never be zero.
    _MIN_VAR = 1e-2

    def __init__(self, nb_inputs):
        """Create zero-initialized statistics for ``nb_inputs`` components."""
        self.n = np.zeros(nb_inputs)
        self.mean = np.zeros(nb_inputs)
        self.mean_diff = np.zeros(nb_inputs)
        self.var = np.zeros(nb_inputs)

    def observe(self, x):
        """Update the running statistics with the sample ``x``.

        ``x`` is presumably an array with ``nb_inputs`` components (anything
        that broadcasts against the stored arrays works) -- confirm against
        the caller.
        """
        self.n += 1
        # The previous mean is needed for the squared-deviation update below,
        # so snapshot it before the in-place mean update overwrites it.
        last_mean = self.mean.copy()
        self.mean += (x - self.mean) / self.n
        self.mean_diff += (x - last_mean) * (x - self.mean)
        self.var = (self.mean_diff / self.n).clip(min=self._MIN_VAR)

    def normalize(self, inputs):
        """Return ``(inputs - mean) / std`` using the current statistics.

        The variance floor is re-applied here so that calling this method
        before any ``observe`` (when ``var`` is still all zeros) yields
        finite values instead of dividing by zero. Once ``observe`` has run,
        ``var`` is already floored and the clip is a no-op.
        """
        obs_std = np.sqrt(np.clip(self.var, self._MIN_VAR, None))
        return (inputs - self.mean) / obs_std