-
Notifications
You must be signed in to change notification settings - Fork 36
Expand file tree
/
Copy pathfql.py
More file actions
270 lines (228 loc) · 10.1 KB
/
fql.py
File metadata and controls
270 lines (228 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import copy
from typing import Any
import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import ActorVectorField, Value
class FQLAgent(flax.struct.PyTreeNode):
    """Flow Q-learning (FQL) agent.

    All sub-networks live in one ``ModuleDict``-backed ``TrainState`` (built in
    ``create``):

    * ``critic`` / ``target_critic``: a 2-member ``Value`` ensemble and its
      Polyak-averaged copy.
    * ``actor_bc_flow``: a velocity field trained with a flow-matching BC loss.
    * ``actor_onestep_flow``: a one-step policy distilled from the BC flow and
      pushed toward high Q values; this is the policy used by ``sample_actions``.
    * ``actor_bc_flow_encoder`` (only when ``config['encoder']`` is set): the BC
      flow's encoder, registered so it can be called on its own.

    Attributes:
        rng: JAX PRNG key owned by the agent.
        network: The ``TrainState`` holding all sub-networks.
        config: Static (non-pytree) configuration mapping.
    """

    rng: Any
    network: Any
    config: Any = nonpytree_field()

    def critic_loss(self, batch, grad_params, rng):
        """Compute the FQL critic (TD) loss.

        The bootstrap target uses next actions from the one-step policy,
        evaluated by the target critic ensemble and aggregated with
        ``config['q_agg']`` ('min'; any other value means the ensemble mean).

        Args:
            batch: Mapping with 'observations', 'actions', 'rewards', 'masks'
                and 'next_observations'.
            grad_params: Parameters being differentiated; used for the online critic.
            rng: PRNG key used to sample next actions.

        Returns:
            Tuple ``(critic_loss, info)`` where ``info`` holds scalar metrics.
        """
        _, sample_rng = jax.random.split(rng)
        next_actions = self.sample_actions(batch['next_observations'], seed=sample_rng)
        # sample_actions already clips; kept as a cheap safeguard.
        next_actions = jnp.clip(next_actions, -1, 1)

        # Leading axis of the critic output is the ensemble axis (reduced over axis 0).
        next_qs = self.network.select('target_critic')(batch['next_observations'], actions=next_actions)
        if self.config['q_agg'] == 'min':
            next_q = next_qs.min(axis=0)
        else:
            next_q = next_qs.mean(axis=0)

        # masks multiply the bootstrap term (presumably 0 at terminal transitions — confirm with the dataset).
        target_q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_q

        # Every ensemble member regresses toward the same target (target_q broadcasts over axis 0).
        q = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)
        critic_loss = jnp.square(q - target_q).mean()

        return critic_loss, {
            'critic_loss': critic_loss,
            'q_mean': q.mean(),
            'q_max': q.max(),
            'q_min': q.min(),
        }

    def actor_loss(self, batch, grad_params, rng):
        """Compute the FQL actor loss.

        The loss is the sum of three terms:

        * BC flow loss: flow matching on the straight path
          ``x_t = (1 - t) * x_0 + t * x_1`` from Gaussian noise ``x_0`` to dataset
          actions ``x_1``, regressing the predicted velocity onto ``x_1 - x_0``.
        * Distillation loss (weighted by ``config['alpha']``): squared error
          between the one-step policy and the Euler-integrated BC flow, both
          driven by the same noise.
        * Q loss: negative ensemble-mean Q of the clipped one-step actions,
          optionally rescaled by a stop-gradient ``1 / mean|Q|`` when
          ``config['normalize_q_loss']`` is set.

        Args:
            batch: Mapping with at least 'observations' and 'actions'.
            grad_params: Parameters being differentiated (both actor networks).
            rng: PRNG key for noise, flow times and the logged action sample.

        Returns:
            Tuple ``(actor_loss, info)`` where ``info`` holds scalar metrics.
        """
        batch_size, action_dim = batch['actions'].shape
        rng, x_rng, t_rng = jax.random.split(rng, 3)

        # BC flow loss.
        x_0 = jax.random.normal(x_rng, (batch_size, action_dim))
        x_1 = batch['actions']
        t = jax.random.uniform(t_rng, (batch_size, 1))
        x_t = (1 - t) * x_0 + t * x_1
        vel = x_1 - x_0

        pred = self.network.select('actor_bc_flow')(batch['observations'], x_t, t, params=grad_params)
        bc_flow_loss = jnp.mean((pred - vel) ** 2)

        # Distillation loss. compute_flow_actions uses the stored (non-grad)
        # params, so the BC flow acts as a fixed teacher for this term.
        rng, noise_rng = jax.random.split(rng)
        noises = jax.random.normal(noise_rng, (batch_size, action_dim))
        target_flow_actions = self.compute_flow_actions(batch['observations'], noises=noises)
        actor_actions = self.network.select('actor_onestep_flow')(batch['observations'], noises, params=grad_params)
        distill_loss = jnp.mean((actor_actions - target_flow_actions) ** 2)

        # Q loss. The critic is evaluated with stored params, so only the actor
        # receives gradients from this term.
        actor_actions = jnp.clip(actor_actions, -1, 1)
        qs = self.network.select('critic')(batch['observations'], actions=actor_actions)
        q = jnp.mean(qs, axis=0)

        q_loss = -q.mean()
        if self.config['normalize_q_loss']:
            lam = jax.lax.stop_gradient(1 / jnp.abs(q).mean())
            q_loss = lam * q_loss

        # Total loss.
        actor_loss = bc_flow_loss + self.config['alpha'] * distill_loss + q_loss

        # Additional metrics for logging.
        actions = self.sample_actions(batch['observations'], seed=rng)
        mse = jnp.mean((actions - batch['actions']) ** 2)

        return actor_loss, {
            'actor_loss': actor_loss,
            'bc_flow_loss': bc_flow_loss,
            'distill_loss': distill_loss,
            'q_loss': q_loss,
            'q': q.mean(),
            'mse': mse,
        }

    @jax.jit
    def total_loss(self, batch, grad_params, rng=None):
        """Compute the summed critic + actor loss.

        Args:
            batch: Training batch (see ``critic_loss`` / ``actor_loss``).
            grad_params: Parameters being differentiated.
            rng: PRNG key; defaults to ``self.rng`` when ``None``.

        Returns:
            Tuple ``(loss, info)`` with metrics prefixed by 'critic/' and 'actor/'.
        """
        info = {}
        rng = rng if rng is not None else self.rng

        rng, actor_rng, critic_rng = jax.random.split(rng, 3)

        critic_loss, critic_info = self.critic_loss(batch, grad_params, critic_rng)
        for k, v in critic_info.items():
            info[f'critic/{k}'] = v

        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
        for k, v in actor_info.items():
            info[f'actor/{k}'] = v

        loss = critic_loss + actor_loss
        return loss, info

    def target_update(self, network, module_name):
        """Polyak-average ``module_name`` into its target copy.

        Computes ``tau * online + (1 - tau) * target`` and stores the result in
        ``network.params['modules_target_<module_name>']`` in place.

        NOTE(review): both the online and target params are read from
        ``self.network`` (the state before the current gradient step), not from
        ``network``, so the target lags the freshly updated critic by one step.
        Confirm this is intended.
        """
        new_target_params = jax.tree_util.tree_map(
            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
            self.network.params[f'modules_{module_name}'],
            self.network.params[f'modules_target_{module_name}'],
        )
        network.params[f'modules_target_{module_name}'] = new_target_params

    @jax.jit
    def update(self, batch):
        """Take one gradient step and a target-critic update.

        Returns:
            Tuple ``(new_agent, info)``: a new agent with the updated network and a
            fresh RNG key, plus the loss metrics.
        """
        new_rng, rng = jax.random.split(self.rng)

        def loss_fn(grad_params):
            return self.total_loss(batch, grad_params, rng=rng)

        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
        self.target_update(new_network, 'critic')

        return self.replace(network=new_network, rng=new_rng), info

    @jax.jit
    def sample_actions(
        self,
        observations,
        seed=None,
        temperature=1.0,
    ):
        """Sample actions from the one-step policy.

        Args:
            observations: Observations whose trailing dims match
                ``config['ob_dims']``; any leading dims are treated as batch dims.
            seed: PRNG key for the input noise. When ``None``, ``self.rng`` is
                used (mirroring ``total_loss``). This does not advance the
                agent's key, so repeated ``seed=None`` calls reuse the same noise.
            temperature: Accepted for interface compatibility; currently unused.

        Returns:
            Actions clipped to [-1, 1], shaped ``(*batch_dims, action_dim)``.
        """
        if seed is None:
            # Previously jax.random.split(None) raised an opaque error here.
            seed = self.rng
        # Only the first subkey feeds the noise; the split is kept so sampled
        # noise is identical to earlier versions for a given seed.
        noise_key, _ = jax.random.split(seed)
        batch_shape = observations.shape[: -len(self.config['ob_dims'])]
        noises = jax.random.normal(noise_key, (*batch_shape, self.config['action_dim']))
        actions = self.network.select('actor_onestep_flow')(observations, noises)
        return jnp.clip(actions, -1, 1)

    @jax.jit
    def compute_flow_actions(
        self,
        observations,
        noises,
    ):
        """Compute actions from the BC flow model using the Euler method.

        Starting from ``noises`` at t=0, takes ``config['flow_steps']`` Euler
        steps of size ``1 / flow_steps`` along the BC velocity field, then clips
        to [-1, 1]. When an encoder is configured, observations are encoded once
        up front and the flow network is called with ``is_encoded=True``.
        """
        if self.config['encoder'] is not None:
            observations = self.network.select('actor_bc_flow_encoder')(observations)
        actions = noises
        # Euler method.
        for i in range(self.config['flow_steps']):
            t = jnp.full((*observations.shape[:-1], 1), i / self.config['flow_steps'])
            vels = self.network.select('actor_bc_flow')(observations, actions, t, is_encoded=True)
            actions = actions + vels / self.config['flow_steps']
        actions = jnp.clip(actions, -1, 1)
        return actions

    @classmethod
    def create(
        cls,
        seed,
        ex_observations,
        ex_actions,
        config,
    ):
        """Create a new agent.

        Note: this writes 'ob_dims' and 'action_dim' into the passed-in
        ``config`` in place before freezing a copy for the agent.

        Args:
            seed: Random seed.
            ex_observations: Example batch of observations.
            ex_actions: Example batch of actions.
            config: Configuration dictionary.

        Returns:
            A freshly initialized ``FQLAgent``.
        """
        rng = jax.random.PRNGKey(seed)
        rng, init_rng = jax.random.split(rng, 2)

        # A (batch, 1) slice used only as an example time input for init.
        ex_times = ex_actions[..., :1]
        ob_dims = ex_observations.shape[1:]
        action_dim = ex_actions.shape[-1]

        # Define encoders.
        encoders = dict()
        if config['encoder'] is not None:
            encoder_module = encoder_modules[config['encoder']]
            encoders['critic'] = encoder_module()
            encoders['actor_bc_flow'] = encoder_module()
            encoders['actor_onestep_flow'] = encoder_module()

        # Define networks.
        critic_def = Value(
            hidden_dims=config['value_hidden_dims'],
            layer_norm=config['layer_norm'],
            num_ensembles=2,
            encoder=encoders.get('critic'),
        )
        actor_bc_flow_def = ActorVectorField(
            hidden_dims=config['actor_hidden_dims'],
            action_dim=action_dim,
            layer_norm=config['actor_layer_norm'],
            encoder=encoders.get('actor_bc_flow'),
        )
        actor_onestep_flow_def = ActorVectorField(
            hidden_dims=config['actor_hidden_dims'],
            action_dim=action_dim,
            layer_norm=config['actor_layer_norm'],
            encoder=encoders.get('actor_onestep_flow'),
        )

        network_info = dict(
            critic=(critic_def, (ex_observations, ex_actions)),
            target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),
            actor_bc_flow=(actor_bc_flow_def, (ex_observations, ex_actions, ex_times)),
            actor_onestep_flow=(actor_onestep_flow_def, (ex_observations, ex_actions)),
        )
        if encoders.get('actor_bc_flow') is not None:
            # Add actor_bc_flow_encoder to ModuleDict to make it separately callable.
            network_info['actor_bc_flow_encoder'] = (encoders.get('actor_bc_flow'), (ex_observations,))
        networks = {k: v[0] for k, v in network_info.items()}
        network_args = {k: v[1] for k, v in network_info.items()}

        network_def = ModuleDict(networks)
        network_tx = optax.adam(learning_rate=config['lr'])
        network_params = network_def.init(init_rng, **network_args)['params']
        network = TrainState.create(network_def, network_params, tx=network_tx)

        # Start the target critic as an exact copy of the online critic.
        params = network.params
        params['modules_target_critic'] = params['modules_critic']

        config['ob_dims'] = ob_dims
        config['action_dim'] = action_dim
        return cls(rng, network=network, config=flax.core.FrozenDict(**config))
def get_config():
    """Build the default hyperparameter configuration for the FQL agent.

    ``ob_dims`` and ``action_dim`` start as typed placeholders; ``FQLAgent.create``
    writes their real values from the example batch.

    Returns:
        An ``ml_collections.ConfigDict`` of agent hyperparameters.
    """
    defaults = {
        'agent_name': 'fql',  # Agent name.
        'ob_dims': ml_collections.config_dict.placeholder(list),  # Observation dimensions (set automatically).
        'action_dim': ml_collections.config_dict.placeholder(int),  # Action dimension (set automatically).
        'lr': 3e-4,  # Learning rate.
        'batch_size': 256,  # Batch size.
        'actor_hidden_dims': (512, 512, 512, 512),  # Actor network hidden dimensions.
        'value_hidden_dims': (512, 512, 512, 512),  # Value network hidden dimensions.
        'layer_norm': True,  # Layer normalization in the critic.
        'actor_layer_norm': False,  # Layer normalization in the actors.
        'discount': 0.99,  # Discount factor.
        'tau': 0.005,  # Target network update rate.
        'q_agg': 'mean',  # Aggregation of target Q values: 'min' or 'mean'.
        'alpha': 10.0,  # BC (distillation) coefficient; needs per-environment tuning.
        'flow_steps': 10,  # Number of Euler steps for the BC flow.
        'normalize_q_loss': False,  # Whether to normalize the Q loss.
        'encoder': ml_collections.config_dict.placeholder(str),  # Visual encoder name (None, 'impala_small', etc.).
    }
    return ml_collections.ConfigDict(defaults)