-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy pathactivation_utils.py
More file actions
183 lines (149 loc) · 6.45 KB
/
activation_utils.py
File metadata and controls
183 lines (149 loc) · 6.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
from __future__ import annotations
import torch as t
class SparseAct:
    """
    A SparseAct is a helper class which represents a vector in the sparse feature basis
    provided by an SAE, jointly with the SAE error term.

    A SparseAct may have three fields:
        act  : the feature activations in the sparse basis
        res  : the SAE error term
        resc : a contracted SAE error term, useful for when we want one number per
               feature and error (instead of having d_model numbers per error)

    Arithmetic, comparison and most tensor-like methods are applied field by field and
    return a new SparseAct. Fields that are None are skipped; for binary operations
    between two SparseActs, a field is kept only when it is present on both operands.
    """

    # Field names every field-wise operation iterates over, in a fixed order.
    _FIELDS = ('act', 'res', 'resc')

    def __init__(
        self,
        act: t.Tensor,
        res: t.Tensor | None = None,
        resc: t.Tensor | None = None,  # contracted residual
    ) -> None:
        self.act = act
        self.res = res
        self.resc = resc

    def _map(self, f, aux=None) -> SparseAct:
        """Apply ``f`` field by field and wrap the results in a new SparseAct.

        Args:
            f: binary callable ``f(field, other)``; unary operations ignore ``other``.
            aux: either another SparseAct (combined field by field, keeping only the
                fields present on both sides) or any other value, which is passed
                unchanged as ``other`` for every present field.
        """
        kwargs = {}
        if isinstance(aux, SparseAct):
            for attr in self._FIELDS:
                mine, theirs = getattr(self, attr), getattr(aux, attr)
                if mine is not None and theirs is not None:
                    kwargs[attr] = f(mine, theirs)
        else:
            for attr in self._FIELDS:
                mine = getattr(self, attr)
                if mine is not None:
                    kwargs[attr] = f(mine, aux)
        return SparseAct(**kwargs)

    def __mul__(self, other) -> SparseAct:
        return self._map(lambda x, y: x * y, other)

    def __rmul__(self, other) -> SparseAct:
        # This will handle float/int * SparseAct by reusing the __mul__ logic
        return self.__mul__(other)

    def __matmul__(self, other: SparseAct) -> SparseAct:
        """Multiply ``act`` elementwise and contract the residuals over the last dim.

        Both operands must carry ``res``. The result carries ``act`` and ``resc``
        (with the contracted dim kept as size 1) but no ``res``.
        """
        assert self.res is not None and other.res is not None, \
            "__matmul__ requires res on both operands"
        return SparseAct(
            act=self.act * other.act,
            resc=(self.res * other.res).sum(dim=-1, keepdim=True),
        )

    def __add__(self, other) -> SparseAct:
        return self._map(lambda x, y: x + y, other)

    def __radd__(self, other) -> SparseAct:
        # Addition is commutative; this also lets builtin sum() start from 0.
        return self.__add__(other)

    def __sub__(self, other) -> SparseAct:
        return self._map(lambda x, y: x - y, other)

    def __rsub__(self, other) -> SparseAct:
        # Computes other - self, e.g. ``1 - sparse_act``.
        return self._map(lambda x, y: y - x, other)

    def __truediv__(self, other) -> SparseAct:
        # Routed through _map so that, like the other binary operators, a field missing
        # on either operand is skipped instead of evaluating ``tensor / None``.
        return self._map(lambda x, y: x / y, other)

    def __rtruediv__(self, other) -> SparseAct:
        # Computes other / self; a SparseAct ``other`` is divided field by field.
        return self._map(lambda x, y: y / x, other)

    def __neg__(self) -> SparseAct:
        return self._map(lambda x, _: -x)

    def __invert__(self) -> SparseAct:
        return self._map(lambda x, _: ~x)

    def __getitem__(self, index: int):
        # Indexes only the feature activations; res/resc are not included.
        return self.act[index]

    def __repr__(self):
        if self.res is None:
            return f"SparseAct(act={self.act}, resc={self.resc})"
        if self.resc is None:
            return f"SparseAct(act={self.act}, res={self.res})"
        # Having both residual fields is an unsupported state, but __repr__ must not
        # raise (debuggers, logging and tracebacks call it), so show every field.
        return f"SparseAct(act={self.act}, res={self.res}, resc={self.resc})"

    def sum(self, dim=None):
        """Sum each present field over ``dim``, or over all elements when ``dim`` is None."""
        def _reduce(x, _):
            return x.sum() if dim is None else x.sum(dim)
        return self._map(_reduce)

    def mean(self, dim: int | None = None):
        """Average each present field over ``dim``, or over all elements when ``dim`` is None."""
        def _reduce(x, _):
            return x.mean() if dim is None else x.mean(dim)
        return self._map(_reduce)

    @property
    def grad(self):
        """A SparseAct holding ``.grad`` of each present field (entries may be None)."""
        return self._map(lambda x, _: x.grad)

    def clone(self):
        return self._map(lambda x, _: x.clone())

    @property
    def value(self):
        # NOTE(review): ``.value`` / ``.save()`` are not torch.Tensor APIs; presumably the
        # fields are tracing proxies at this point (e.g. nnsight) — confirm with callers.
        return self._map(lambda x, _: x.value)

    def save(self):
        return self._map(lambda x, _: x.save())

    def detach(self):
        return self._map(lambda x, _: x.detach())

    def to_tensor(self):
        """Concatenate ``act`` with the single residual field along the last dim.

        Raises:
            AssertionError: if neither ``res`` nor ``resc`` is set.
            ValueError: if both ``res`` and ``resc`` are set.
        """
        if self.resc is None:
            assert self.res is not None
            return t.cat([self.act, self.res], dim=-1)
        if self.res is None:
            return t.cat([self.act, self.resc], dim=-1)
        raise ValueError("SparseAct has both residual and contracted residual. This is an unsupported state.")

    def to(self, device):
        """Move every present field to ``device`` IN PLACE and return ``self``.

        Unlike the other methods, this mutates the instance instead of building a new one.
        """
        for attr in self._FIELDS:
            field = getattr(self, attr)
            if field is not None:
                setattr(self, attr, field.to(device))
        return self

    # Comparisons are elementwise and return a SparseAct of masks, not a bool.
    # Because __eq__ is defined without __hash__, instances are unhashable.
    def __eq__(self, other):  # type: ignore
        return self._map(lambda x, y: x == y, other)

    def __gt__(self, other):
        return self._map(lambda x, y: x > y, other)

    def __lt__(self, other):
        return self._map(lambda x, y: x < y, other)

    def nonzero(self):
        return self._map(lambda x, _: x.nonzero())

    def squeeze(self, dim):
        return self._map(lambda x, _: x.squeeze(dim=dim))

    def expand_as(self, other):
        return self._map(lambda x, y: x.expand_as(y), other)

    def zeros_like(self):
        return self._map(lambda x, _: t.zeros_like(x))

    def ones_like(self):
        return self._map(lambda x, _: t.ones_like(x))

    def abs(self):
        return self._map(lambda x, _: x.abs())