-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
109 lines (95 loc) · 4.38 KB
/
plot.py
File metadata and controls
109 lines (95 loc) · 4.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import numpy as np
import argparse
# (log directory, legend label) pairs for the runs drawn in the figure.
# Commented-out entries are other runs that are currently left out.
event_logs = [
    # ('runs/Dec07_15-45-29_cassie_attn', 'Comm-Attn'),
    # ('runs/Dec07_15-59-06_cassie_baseline', 'Baseline'),
    ('runs/Dec07_15-45-29_cassie_attn', 'Neighborhood-all'),
    ('runs/Dec07_18-58-53_cassie_ca_nr_1', 'Neighborhood-1'),
    ('runs/Dec07_18-59-22_cassie_ca_nr_3', 'Neighborhood-3'),
    ('runs/Dec07_17-56-20_cassie_ca_nr_5', 'Neighborhood-5'),
    # ('runs/Dec07_17-56-57_cassie_ca_nr_10', 'Neighborhood-10'),
]

# Build one accumulator per run and load its events once up front, keeping
# each accumulator paired with the label used in the plot legends.
event_accumulators = []
for log_dir, label in event_logs:
    accumulator = EventAccumulator(log_dir)
    accumulator.Reload()
    event_accumulators.append((accumulator, label))
def extract_scalar_data(event_acc, tag):
    """Return the scalar events stored under ``tag``.

    Args:
        event_acc: Object exposing a ``Scalars(tag)`` method; in this script
            it is one of the entries of ``event_accumulators``.
        tag: Name of the scalar series to fetch, e.g. ``'Episode/Return'``.

    Returns:
        Whatever ``event_acc.Scalars(tag)`` returns, unmodified.
    """
    scalar_events = event_acc.Scalars(tag)
    return scalar_events
def extract_steps_and_values(scalar_data, window_size=500):
    """Split the leading scalar events into parallel step and value lists.

    Args:
        scalar_data: Sliceable sequence of entries that each expose ``.step``
            and ``.value`` attributes.
        window_size: Maximum number of leading entries to keep. Despite the
            name this is a cap on how many points are read, not a smoothing
            window.

    Returns:
        A ``(steps, values)`` tuple of two lists of equal length.
    """
    # Slice once, then read both attributes from the same truncated view.
    capped = scalar_data[:window_size]
    return [event.step for event in capped], [event.value for event in capped]
def moving_avg_std(values, window_size=10):
    """Compute the mean and standard deviation of every full sliding window.

    Window ``i`` covers ``values[i:i + window_size]``. Only complete windows
    are produced, so both outputs hold ``len(values) - window_size + 1``
    elements, or none at all when there are fewer samples than
    ``window_size``.

    Args:
        values: One-dimensional sequence of numbers.
        window_size: Number of consecutive samples per window; must be >= 1.

    Returns:
        A ``(means, stds)`` tuple of 1-D float arrays. ``stds`` uses NumPy's
        default ``ddof=0`` (population standard deviation).

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        # A non-positive window would otherwise yield NaNs or meaningless
        # slices without any indication that the argument was wrong.
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    samples = np.asarray(values, dtype=float)
    window_count = samples.size - window_size + 1
    if window_count <= 0:
        # Not enough samples for even one complete window.
        return np.array([]), np.array([])

    # Row i holds the indices i .. i + window_size - 1, so integer-array
    # indexing materialises every window as one row and NumPy reduces all of
    # them in a single call instead of one Python-level iteration per window.
    index = np.arange(window_count)[:, None] + np.arange(window_size)[None, :]
    windows = samples[index]
    return windows.mean(axis=1), windows.std(axis=1)
# (data-key suffix, TensorBoard scalar tag) for every metric update_plot draws.
_PLOT_METRICS = (
    ('return', 'Episode/Return'),
    ('value_loss', 'Episode/Value_Loss'),
    ('policy_loss', 'Episode/Policy_Loss'),
)


def _smoothed_series(event_acc, tag):
    """Return ``(steps, mean, std)`` of the smoothed series for one tag."""
    try:
        scalar_data = extract_scalar_data(event_acc, tag)
    except KeyError:
        # EventAccumulator.Scalars raises KeyError for a tag the run has not
        # logged (yet). Draw an empty series for that run rather than
        # aborting the refresh of every other curve.
        scalar_data = []
    steps, values = extract_steps_and_values(scalar_data)
    mean, std = moving_avg_std(values)
    # Keep only the trailing len(mean) steps so x and y have equal length.
    # Written as an explicit start index because steps[-0:] would return the
    # whole list when no window could be formed.
    return steps[len(steps) - len(mean):], mean, std


def update_plot(frame):
    """Reload every run and redraw the three metric subplots.

    Args:
        frame: Frame value supplied by ``FuncAnimation``; unused. ``main``
            passes ``None`` when rendering a single static figure.
    """
    data = {}
    for event_acc, name in event_accumulators:
        # Pick up events written since the previous refresh.
        event_acc.Reload()
        series = {}
        for key, tag in _PLOT_METRICS:
            steps, mean, std = _smoothed_series(event_acc, tag)
            # plot_data looks these up as f'steps_{key}' / f'mean_{key}' /
            # f'std_{key}', so the key names must keep this exact form.
            series[f'steps_{key}'] = steps
            series[f'mean_{key}'] = mean
            series[f'std_{key}'] = std
        data[name] = series

    # Start from a blank figure so curves from the previous refresh do not
    # pile up underneath the new ones.
    plt.clf()
    plot_data(1, data, 'return', 'Return', 'Episode Return')
    plot_data(2, data, 'value_loss', 'Value Loss', 'Value Loss')
    plot_data(3, data, 'policy_loss', 'Policy Loss', 'Policy Loss')
    plt.tight_layout()
def plot_data(subplot_index, data, key, ylabel, title):
    """Draw one metric for every run into a row of the 3x1 subplot grid.

    Args:
        subplot_index: 1-based position of the subplot within the 3x1 grid.
        data: Mapping from run name to a dict holding ``steps_<key>``,
            ``mean_<key>`` and ``std_<key>`` entries. The mean and std must
            support element-wise ``-`` and ``+`` (NumPy arrays in this
            script).
        key: Metric suffix used to build the three lookup keys.
        ylabel: Label for the y axis.
        title: Title of the subplot.
    """
    axes = plt.subplot(3, 1, subplot_index)
    for run_name, series in data.items():
        xs = series[f'steps_{key}']
        center = series[f'mean_{key}']
        spread = series[f'std_{key}']
        axes.plot(xs, center, label=f'{run_name} {key.capitalize()} (smoothed)')
        # Shaded band of one standard deviation either side of the mean.
        axes.fill_between(xs, center - spread, center + spread, alpha=0.3)
    axes.set_xlabel('Steps')
    axes.set_ylabel(ylabel)
    axes.legend()
    axes.set_title(title)
def main(update, output_path="plot_ablation.png", interval_ms=10000):
    """Render the training curves, either live or as a saved image.

    Args:
        update: When true, open an interactive window that is redrawn
            periodically; otherwise draw once and save the figure to disk.
        output_path: File the static figure is written to when ``update`` is
            false.
        interval_ms: Delay between redraws in milliseconds when ``update`` is
            true. The default of 10000 is 10 seconds.
    """
    fig = plt.figure(figsize=(9, 12))
    if update:
        # Keep the animation bound to a name while the window is open: an
        # unreferenced FuncAnimation can be garbage-collected, which would
        # stop the periodic redraws.
        ani = FuncAnimation(fig, update_plot, interval=interval_ms)
        plt.show()
    else:
        update_plot(None)
        plt.savefig(output_path)
if __name__ == "__main__":
    # Command-line entry point: the only option is the --update switch,
    # whose boolean value is handed straight to main().
    cli = argparse.ArgumentParser(description="Plot training metrics.")
    cli.add_argument(
        '--update',
        action='store_true',
        default=False,
        help="Update the plot in real-time.",
    )
    main(cli.parse_args().update)