-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
141 lines (113 loc) · 4.03 KB
/
utils.py
File metadata and controls
141 lines (113 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
This file contains general helper functions for the Risk environment
"""
import numpy as np
def validate_q_func_for_argmax(q_func, valid_mask):
    """
    Returns a copy of the q function with -inf wherever valid_mask is not 1,
    so that a later argmax can never select an invalid action

    :param q_func: float vector to correct for validation
    :param valid_mask: int vector of valid actions (binary); 1 marks a valid action
    :return valid_q_func: float array of the same size as q_func, equal to q_func
        where valid_mask is 1 and -inf everywhere else. The int -1 is returned
        (after printing a message) when the two inputs differ in length or when
        no entry of valid_mask is 1
    """
    if len(q_func) != len(valid_mask):
        print("Q function and mask different sizes")
        return -1

    # Only entries exactly equal to 1 count as valid
    is_valid = np.asarray(valid_mask) == 1
    if not is_valid.any():
        print("Warning: no valid actions")
        return -1

    return np.where(is_valid, np.asarray(q_func, dtype=float), float("-inf"))
# https://stackoverflow.com/questions/34968722/how-to-implement-the-softmax-function-in-python
def softmax_valid(q_func):
    """
    Returns the softmax output of a general linear activation, giving zero
    probability to entries that are (effectively) zero

    Entries whose magnitude is at most EFFECTIVE_ZERO are treated as
    invalidated q's: they are replaced by -inf before exponentiation.
    :param q_func: float vector of activations; the caller's object is not modified
    :return softmax_q: float array of the same length as q_func. If every entry
        is invalidated the result is all NaN (0/0), since no distribution exists
    """
    EFFECTIVE_ZERO = 1e-20

    # Work on a float copy: the caller's vector must not be overwritten, and an
    # integer array could not store -inf
    scores = np.array(q_func, dtype=float)
    if scores.size == 0:
        return scores
    scores[np.abs(scores) <= EFFECTIVE_ZERO] = float("-inf")  # Invalidate q's

    # Subtracting the largest score leaves the softmax unchanged but keeps
    # exp() from overflowing. Skip the shift when the peak is not finite, as
    # (-inf) - (-inf) would be NaN
    peak = np.max(scores)
    shift = peak if np.isfinite(peak) else 0.0
    num = np.exp(scores - shift)
    den = np.sum(num)
    return num / den
def choose_by_weight(q_func):
    """
    Returns an index in q_func (action choice) based on relative weights in the q function

    :param q_func: vector of weights (list or array); it is not modified
    :return choice: index sampled with probability proportional to its weight.
        If the weights sum to zero the index is sampled uniformly. If any weight
        is negative a warning is printed and all weights are shifted up by the
        minimum before sampling, so the smallest weight gets probability zero
    """
    weights = np.array(q_func, dtype=float)  # Copy, so the shift below never touches the caller's data
    n_choices = len(weights)

    if np.sum(weights) == 0:
        print("WARNING: weighted choices sum to zero, sampling randomly")
        return np.random.randint(0, n_choices)

    lowest = np.min(weights)
    if lowest < 0:
        print("WARNING: weighted choices contain negative values, sampling incorrect")
        weights -= lowest

    total = np.sum(weights)
    if total == 0:
        # Every weight equalled the (negative) minimum, so the shift left
        # nothing to normalise by; dividing would give NaN probabilities
        print("WARNING: weighted choices sum to zero, sampling randomly")
        return np.random.randint(0, n_choices)

    probs = weights / total
    choice = np.random.choice(a=n_choices, size=1, p=probs)
    return choice[0]
def epsilon_greedy_valid(q_func, valid_mask, epsilon):
    """
    Returns an epsilon greedy action from a subset of function defined by mask
    Only chooses valid actions as specified by the mask

    :param q_func: float vector to return argmax in greedy case
    :param valid_mask: int vector of valid actions; 1 marks a valid action
    :param epsilon: probability under which to choose non-greedily
    :return arg: int choice, as an index into q_func. Returns -1 (after printing
        a message) when q_func and valid_mask differ in length or when no
        action is valid
    """
    if len(q_func) != len(valid_mask):
        print("Q function and mask different sizes")
        return -1

    # Positions in the original q function that are allowed to be chosen
    valid_indices = [ii for ii, flag in enumerate(valid_mask) if flag == 1]
    if not valid_indices:
        print("No valid actions")
        return -1

    # With a single valid action there is nothing to choose between. Return it
    # directly rather than relying on how epsilon_greedy signals that case
    if len(valid_indices) == 1:
        return valid_indices[0]

    # Choose among the valid q values only, then map back to the original index
    valid_q_func = [q_func[ii] for ii in valid_indices]
    valid_action = epsilon_greedy(valid_q_func, epsilon)
    return valid_indices[valid_action]
def epsilon_greedy(q_func, epsilon):
    """
    Defines a policy which acts greedily except for epsilon exceptions

    A single uniform draw in [0, 1) decides the action. If it is above epsilon
    the argmax of q_func is returned. Otherwise the interval [0, epsilon] is cut
    into one equal slice per non-greedy action and the slice containing the
    draw picks which of those actions is returned.
    :param q_func: q function returned by an attack network
    :param epsilon: the threshold value
    :return index: int the index of the corresponding action, or -1 when q_func
        holds exactly one entry (there is no non-greedy alternative)
    """
    eps_choices = len(q_func) - 1  # Number of non-greedy actions
    if eps_choices == 0:
        return -1

    choice = np.random.uniform()
    max_action = np.argmax(q_func)
    if choice > epsilon:
        return max_action

    eps_slice = epsilon / eps_choices
    # Default to the last slice: a draw that is not strictly below the top
    # boundary (choice == epsilon, epsilon == 0, or float rounding) would
    # otherwise match no slice and leave the action undefined
    action = eps_choices - 1
    for act_slice in range(eps_choices):
        if choice < eps_slice * (1 + act_slice):
            action = act_slice
            break

    # Slices number the non-greedy actions only, so skip over the greedy index
    if action >= max_action:  # Increment if past max_action
        action += 1
    return action