a2c.py
import os
import numpy as np
import tensorflow as tf
from diplomacy import Game
from diplomacy_research.models import state_space
from tornado import gen
from RL.reward import Reward, get_returns, get_average_reward
from RL.actor import ActorRL
from RL.critic import CriticRL
from tensorflow.keras.optimizers import Adam
# Killing optional CPU driver warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


class A2C:
    def __init__(self, actor_rl, critic_rl):
        """
        Initialize the actor and critic models and a shared Adam optimizer.
        """
        self.actor = actor_rl
        self.critic = critic_rl
        self.optimizer = Adam(learning_rate=0.001)
    @gen.coroutine
    def generate_trajectory(self):
        """
        Play one complete self-play game and collect training data for one
        randomly chosen power: per-step action probabilities, critic value
        estimates, supply-center-based returns, and average rewards.
        """
        game = Game()
        powers = list(game.powers)
        np.random.shuffle(powers)
        power1 = powers[0]
        powers_others = powers[1:]
        action_probs = []
        orders = []
        values = []
        supply_centers = [{power1: game.get_centers(power1)}]
        while not game.is_game_done:
            order, action_prob = self.actor.get_orders(game, [power1])
            orders_others = {
                power_name: self.actor.get_orders(game, [power_name])
                for power_name in powers_others}
            board = tf.convert_to_tensor(
                state_space.dict_to_flatten_board_state(game.get_state(), game.map),
                dtype=tf.float32)
            board = tf.reshape(board, (1, 81 * 35))
            state_value = self.critic.call(board)
            # Indexing because get_orders can return a list of order lists for multiple powers
            game.set_orders(power1, order[0])
            for power_name, power_orders in orders_others.items():
                orders_list, probs = power_orders
                game.set_orders(power_name, orders_list[0])
            game.process()
            # Collect data
            supply_centers.append({power1: game.get_centers(power1)})
            action_probs.append(action_prob)
            orders.append(order)
            values.append(state_value)
            # local_rewards.append(reward_class.get_local_reward(power1))
            # global_rewards.append(0 if not game.is_game_done else reward_class.get_terminal_reward(power1))
        rewards = get_average_reward([supply_centers])
        returns = get_returns([supply_centers])  # wrapped in a list to match the shape [bs, game_length, dict]
        return action_probs, returns, values, rewards
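
    # Rough sketch of what the two loss functions are expected to compute (this is an
    # assumption about ActorRL.loss_function / CriticRL.loss_function, not their actual
    # code), using the advantage A_t = R_t - V(s_t):
    #   actor_loss  ~ -sum_t log pi(a_t | s_t) * stop_gradient(A_t)
    #   critic_loss ~  sum_t (R_t - V(s_t))^2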

    @gen.coroutine
    def train(self, num_episodes):
        """
        Self-play training loop for A2C. Generates a complete trajectory for one episode,
        then updates both the actor and critic networks using that trajectory.

        :param num_episodes: number of episodes to train the networks for
        :returns: total reward per episode
        """
        eps_rewards = []
        for eps in range(num_episodes):
            with tf.GradientTape(persistent=True) as tape:
                # generate_trajectory is a coroutine, so yield it to get its result
                action_probs, returns, values, rewards = yield self.generate_trajectory()
                actor_loss = self.actor.loss_function(
                    action_probs, values, returns)  # + .05 * calc_entropy(policy[0])
                critic_loss = self.critic.loss_function(
                    values, returns)
            # The persistent tape allows two separate gradient computations
            actor_grad = tape.gradient(
                actor_loss, self.actor.trainable_variables)
            critic_grad = tape.gradient(
                critic_loss, self.critic.trainable_variables)
            del tape
            self.optimizer.apply_gradients(
                zip(actor_grad, self.actor.trainable_variables))
            self.optimizer.apply_gradients(
                zip(critic_grad, self.critic.trainable_variables))
            eps_rewards.append(sum(rewards))
            print("A2C training episode number:", eps)
        return eps_rewards
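

# A minimal driver sketch, not part of the original module: the ActorRL / CriticRL
# constructor calls below are assumptions about their interfaces in RL.actor and
# RL.critic; adjust them to match the real classes. Because train() is a tornado
# coroutine, it is driven to completion on an IOLoop here.
if __name__ == "__main__":
    from tornado.ioloop import IOLoop

    actor = ActorRL()    # assumed no-argument constructor
    critic = CriticRL()  # assumed no-argument constructor
    a2c = A2C(actor, critic)

    # run_sync blocks until the coroutine finishes and returns its result
    episode_rewards = IOLoop.current().run_sync(lambda: a2c.train(10))
    print("Per-episode rewards:", episode_rewards)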