-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathstore_heatmap.py
More file actions
88 lines (70 loc) · 3.29 KB
/
store_heatmap.py
File metadata and controls
88 lines (70 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import sys
import numpy as np
import argparse
import os
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
# Use the non-interactive Agg backend: figures are only written to files via
# savefig, never shown in a window.
plt.switch_backend(newbackend='agg')
def argparser(argv=None):
    """Parse the command-line options of this heatmap script.

    The five experiment options that have no default (mechanism, n_agent,
    reward_pool, review_history, window) are required. ``n_agent`` sizes the
    averaged arrays, and the script expects these five to be the first
    command-line tokens because it names the data file and output directory
    from them. Previously, omitting one still parsed successfully and the
    script then failed later with an unrelated error
    (e.g. ``np.zeros((None, ...))``).

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding all options defined below.
    """
    parser = argparse.ArgumentParser()
    # Experiment identifiers: no sensible defaults, so they must be given.
    parser.add_argument('--mechanism', type=str, required=True)
    parser.add_argument('--n_agent', type=int, required=True)
    parser.add_argument('--reward_pool', type=int, required=True)
    parser.add_argument('--review_history', type=int, required=True)
    parser.add_argument('--window', type=int, required=True)
    # Episode count, recording intervals and averaging settings.
    parser.add_argument('--n_episode', type=int, default=500)
    parser.add_argument('--record_term_1', type=int, default=10)
    parser.add_argument('--record_term_2', type=int, default=5)
    parser.add_argument('--range_endeavor', type=int, default=10)
    parser.add_argument('--n_average', type=int, default=100)
    return parser.parse_args(argv)
def _mean_over_runs(runs, episode, n_average, shape):
    """Return the element-wise mean of ``runs[i][episode]`` for i in range(n_average).

    Each record is divided by ``n_average`` before being accumulated into a
    zero array of ``shape``. Raises IndexError if fewer than ``n_average``
    runs, or fewer than ``episode + 1`` records in a run, are available.
    """
    mean = np.zeros(shape)
    for i in range(n_average):
        mean += np.array(runs[i][episode]) / n_average
    return mean


def _save_heatmap(data, path):
    """Draw ``data`` as a seaborn heatmap (x tick labels on top), save it to ``path`` and close the figure.

    Closing is essential: the per-episode loop creates one figure per record,
    and pyplot otherwise keeps every one of them alive in memory.
    """
    fig = plt.figure()
    ax = sns.heatmap(data)
    ax.xaxis.tick_top()
    fig.savefig(path)
    plt.close(fig)


if __name__ == '__main__':
    args = argparser()

    # Experiment tag = the first five raw CLI tokens with the leading "--"
    # stripped. NOTE(review): this presumably relies on the identifying
    # options being passed first, in "--name=value" form, matching how the
    # ./data/*.pkl files were named -- confirm against the producing script.
    if len(sys.argv) < 6:
        sys.exit("expected the five experiment options as the first arguments")
    tag_parts = [token[2:] for token in sys.argv[1:6]]
    filename = "./data/{}.pkl".format("_".join(tag_parts))
    image_dir = "./visualization/{}/images".format("/".join(tag_parts))

    # Load the recorded statistics. NOTE: pickle.load can execute arbitrary
    # code -- only point this script at data files you produced yourself.
    with open(filename, 'rb') as f:
        dict_ = pickle.load(f)

    # Indexed [run][record]; the loops below step through the records of
    # all_total_beta_lists with record_term_1 and of all_details with
    # record_term_2.
    all_total_beta_lists = dict_['all_total_beta_lists']
    all_details = dict_['all_details_total_beta_lists']
    shape = (args.n_agent, args.range_endeavor)

    # Weighted-average endeavor heatmap: for each record, weight the
    # run-averaged (n_agent, range_endeavor) matrix by the level index
    # 0..range_endeavor-1 and sum each agent's row.
    levels = np.arange(0., args.range_endeavor)
    weighted_endeavor_list = []
    for episode in range(args.n_episode // args.record_term_1 + 1):
        avg_total_beta_lists = _mean_over_runs(
            all_total_beta_lists, episode, args.n_average, shape)
        weighted_endeavor_list.append(avg_total_beta_lists @ levels)

    os.makedirs(image_dir, exist_ok=True)
    _save_heatmap(np.array(weighted_endeavor_list),
                  "{}/weighted_endeavor".format(image_dir))

    # Detailed beta-table heatmaps, one image per record; transposed so
    # agents run along the x axis and endeavor levels along the y axis.
    for episode in range(args.n_episode // args.record_term_2 + 1):
        avg_details = _mean_over_runs(all_details, episode, args.n_average, shape)
        _save_heatmap(avg_details.T, "{}/{}".format(image_dir, episode))