forked from tinygrad/tinygrad
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelpers.py
More file actions
20 lines (18 loc) · 752 Bytes
/
helpers.py
File metadata and controls
20 lines (18 loc) · 752 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from tinygrad.device import JITRunner
from tinygrad.nn.state import get_parameters
from tinygrad import Tensor
from tinygrad.helpers import Context
def derandomize_model(model):
    """Swap every parameter of ``model`` onto fresh empty storage and realize it.

    Each tensor returned by ``get_parameters(model)`` has its ``lazydata``
    replaced by the ``lazydata`` of a new ``Tensor.empty`` with the same
    shape, device and dtype. The parameter is then realized. The parameters
    are mutated in place and nothing is returned.

    NOTE(review): presumably used so tests skip the cost of random weight
    initialization. The resulting values come from ``Tensor.empty`` and
    should not be relied on. Confirm with callers.
    """
    # Context(GRAPH=0) overrides the GRAPH setting for the duration of the swap;
    # presumably so this setup work does not show up in graph output. TODO confirm.
    with Context(GRAPH=0):
        for param in get_parameters(model):
            blank = Tensor.empty(param.shape, device=param.device, dtype=param.dtype)
            param.lazydata = blank.lazydata
            param.realize()
def assert_jit_cache_len(fxn, expected_len):
    """Assert that the jitted ``fxn`` has captured ``expected_len`` items.

    Two layouts of ``fxn.jit_cache`` are accepted:

    * If the first cached item's ``prg`` is a ``JITRunner``, the top-level
      cache length itself must equal ``expected_len``.
    * Otherwise the cache must contain exactly one item. That item's ``prg``
      type name must end in ``'Graph'``, and the ``jit_cache`` of that
      ``prg`` must hold ``expected_len`` items.

    Args:
        fxn: object exposing a ``jit_cache`` list of items with a ``.prg``
            attribute (presumably a ``TinyJit`` wrapper; verify against callers).
        expected_len: expected number of captured items.

    Raises:
        AssertionError: if the cache is empty or does not match either layout.
    """
    cache = fxn.jit_cache
    assert len(cache) > 0, "expected a non-empty jit_cache"
    first_prg = cache[0].prg
    if isinstance(first_prg, JITRunner):
        # Ungraphed layout: the top-level entries are counted directly.
        assert len(cache) == expected_len, \
            f"expected {expected_len} jit_cache entries, got {len(cache)}"
    else:
        # Graphed layout: the whole cache is expected to collapse into one graph runner.
        assert len(cache) == 1, \
            f"expected a single graph entry in jit_cache, got {len(cache)}"
        # until we have a better way of typing the prg in JitItem
        prg_type_name = type(first_prg).__name__
        assert prg_type_name.endswith('Graph'), \
            f"expected a *Graph runner in jit_cache, got {prg_type_name}"
        assert len(first_prg.jit_cache) == expected_len, \
            f"expected {expected_len} entries inside {prg_type_name}, got {len(first_prg.jit_cache)}"