forked from TransformerLensOrg/TransformerLens
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_tl.py
More file actions
21 lines (17 loc) · 650 Bytes
/
test_tl.py
File metadata and controls
21 lines (17 loc) · 650 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from transformer_lens import HookedTransformer
import torch
from typing import cast
# Smoke test: load the pretrained "gpt2-small" model through TransformerLens,
# run a single prompt with caching enabled, and report what the returned
# logits look like at runtime.
print("Loading model...")
model = HookedTransformer.from_pretrained("gpt2-small")

print("Running model...")
# run_with_cache returns a pair: the model output and the cached activations.
logits, activations = model.run_with_cache("Hello World")

# Runtime type plus a truncated repr (only the first 200 characters).
print("type(logits) =", type(logits))
print("repr(logits)[:200] =", repr(logits)[:200])

# Narrow the type before touching `.shape` so static checkers accept it.
if isinstance(logits, torch.Tensor):
    print("logits.shape (runtime):", logits.shape)
else:
    # cast() only informs type-checkers; it performs no runtime conversion,
    # so the shape is read defensively with getattr and may be None.
    logits = cast(torch.Tensor, logits)
    print("After cast, logits.shape:", getattr(logits, "shape", None))