Skip to content

Commit 9fcf45a

Browse files
committed
adds HookedTransformer.generate_stream()
1 parent e65fafb commit 9fcf45a

File tree

3 files changed

+295
-17
lines changed

3 files changed

+295
-17
lines changed

demos/BERT.ipynb

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,10 @@
4343
"name": "stderr",
4444
"output_type": "stream",
4545
"text": [
46-
"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_39188/4022418010.py:26: DeprecationWarning:\n",
47-
"\n",
48-
"`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
49-
"\n",
50-
"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_39188/4022418010.py:27: DeprecationWarning:\n",
51-
"\n",
52-
"`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
53-
"\n"
46+
"/var/folders/g4/9w0tw_5j2r79y6qb2v43vls00000gn/T/ipykernel_2797/4022418010.py:26: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
47+
" ipython.magic(\"load_ext autoreload\")\n",
48+
"/var/folders/g4/9w0tw_5j2r79y6qb2v43vls00000gn/T/ipykernel_2797/4022418010.py:27: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
49+
" ipython.magic(\"autoreload 2\")\n"
5450
]
5551
}
5652
],
@@ -120,18 +116,18 @@
120116
{
121117
"data": {
122118
"text/html": [
123-
"<div id=\"circuits-vis-8c91db10-74f4\" style=\"margin: 15px 0;\"/>\n",
119+
"<div id=\"circuits-vis-77c428e0-6b15\" style=\"margin: 15px 0;\"/>\n",
124120
" <script crossorigin type=\"module\">\n",
125-
" import { render, Hello } from \"https://unpkg.com/circuitsvis@1.43.2/dist/cdn/esm.js\";\n",
121+
" import { render, Hello } from \"https://unpkg.com/circuitsvis@1.43.3/dist/cdn/esm.js\";\n",
126122
" render(\n",
127-
" \"circuits-vis-8c91db10-74f4\",\n",
123+
" \"circuits-vis-77c428e0-6b15\",\n",
128124
" Hello,\n",
129125
" {\"name\": \"Neel\"}\n",
130126
" )\n",
131127
" </script>"
132128
],
133129
"text/plain": [
134-
"<circuitsvis.utils.render.RenderedHTML at 0x13a9760d0>"
130+
"<circuitsvis.utils.render.RenderedHTML at 0x11fa11690>"
135131
]
136132
},
137133
"execution_count": 3,
@@ -150,7 +146,18 @@
150146
"cell_type": "code",
151147
"execution_count": 4,
152148
"metadata": {},
153-
"outputs": [],
149+
"outputs": [
150+
{
151+
"name": "stderr",
152+
"output_type": "stream",
153+
"text": [
154+
"/Users/anthonyduong/Code/TransformerLens/.venv/lib/python3.11/site-packages/transformers/utils/hub.py:128: FutureWarning:\n",
155+
"\n",
156+
"Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n",
157+
"\n"
158+
]
159+
}
160+
],
154161
"source": [
155162
"# Import stuff\n",
156163
"import torch\n",
@@ -168,7 +175,7 @@
168175
{
169176
"data": {
170177
"text/plain": [
171-
"<torch.autograd.grad_mode.set_grad_enabled at 0x2a285a790>"
178+
"<torch.autograd.grad_mode.set_grad_enabled at 0x31253f9d0>"
172179
]
173180
},
174181
"execution_count": 5,
@@ -208,7 +215,7 @@
208215
"output_type": "stream",
209216
"text": [
210217
"Moving model to device: mps\n",
211-
"Loaded pretrained model bert-base-cased into HookedTransformer\n"
218+
"Loaded pretrained model bert-base-cased into HookedEncoder\n"
212219
]
213220
}
214221
],
@@ -366,7 +373,7 @@
366373
],
367374
"metadata": {
368375
"kernelspec": {
369-
"display_name": "Python 3",
376+
"display_name": ".venv",
370377
"language": "python",
371378
"name": "python3"
372379
},
@@ -380,7 +387,7 @@
380387
"name": "python",
381388
"nbconvert_exporter": "python",
382389
"pygments_lexer": "ipython3",
383-
"version": "3.10.15"
390+
"version": "3.11.9"
384391
},
385392
"orig_nbformat": 4
386393
},

tests/acceptance/test_hooked_transformer.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,44 @@ def test_bloom_similarity_with_hf_model_with_kv_cache_activated():
195195
assert output_tf == output_hf_str
196196

197197

198+
def test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream():
    """Streaming generation (greedy, KV cache on) must match HF's greedy output."""
    tf_model = HookedTransformer.from_pretrained(
        "bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
    )

    hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
    hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

    stream = tf_model.generate_stream(
        text,
        do_sample=False,
        use_past_kv_cache=True,
        verbose=False,
        max_new_tokens=10,
        max_tokens_per_yield=10,
    )

    # Drain the generator; its final decoded string arrives as the generator's
    # return value, i.e. via StopIteration.value.
    final_output = None
    while True:
        try:
            next(stream)
        except StopIteration as stop:
            final_output = stop.value
            break

    hf_input_ids = hf_tokenizer(text, return_tensors="pt").input_ids
    output_hf_tokens = hf_model.generate(
        hf_input_ids,
        do_sample=False,
        max_new_tokens=10,
    )
    output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True)

    assert (
        final_output == output_hf_str
    ), f"\nStreaming output: {final_output}\nHF output: {output_hf_str}"
234+
235+
198236
def check_norm_folding(
199237
model_name,
200238
hf_model=None,

transformer_lens/HookedTransformer.py

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import logging
1313
import os
14+
from collections.abc import Generator
1415
from typing import (
1516
Dict,
1617
List,
@@ -2340,6 +2341,238 @@ def generate(
23402341
else:
23412342
return embeds
23422343

2344+
@torch.inference_mode()
def generate_stream(
    self,
    input: Union[str, Float[torch.Tensor, "batch pos"]] = "",
    max_new_tokens: int = 10,
    max_tokens_per_yield: int = 25,
    stop_at_eos: bool = True,
    eos_token_id: Optional[int] = None,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    use_past_kv_cache: bool = True,
    prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
    return_type: Optional[str] = "input",
    verbose: bool = True,
) -> Generator[
    Int[torch.Tensor, "batch pos"], None, Union[str, Int[torch.Tensor, "batch pos"]]
]:
    """Stream tokens from the model as they are generated.

    Sample tokens from the model until the model outputs eos_token or max_new_tokens is
    reached, yielding batches of tokens progressively during generation rather than waiting
    for the entire sequence to be generated.

    To avoid fiddling with ragged tensors, if we input a batch of text and some sequences
    finish (by producing an EOT token), we keep running the model on the entire batch, but
    throw away the output for a finished sequence and just keep adding EOTs to pad.

    This supports entering a single string, but not a list of strings - if the strings don't
    tokenize to exactly the same length, this gets messy. If that functionality is needed,
    convert them to a batch of tokens and input that instead.

    Args:
        input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch,
            pos]) or a text string (this will be converted to a batch of tokens with batch size
            1).
        max_new_tokens (int): Maximum number of tokens to generate.
        max_tokens_per_yield (int): Maximum number of tokens to accumulate before yielding.
            Controls how frequently the function yields tokens during generation.
        stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
        eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
            of sentence. If None, use the tokenizer's eos_token_id - required if using
            stop_at_eos. It's also possible to provide a list of token IDs (not just the
            eos_token_id), in which case the generation will stop when any of them are output
            (useful e.g. for stable_lm).
        do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
            greedy search (take the max logit each time).
        top_k (int): Number of tokens to sample from. If None, sample from all tokens.
        top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
            we take the top tokens with cumulative probability >= top_p.
        temperature (float): Temperature for sampling. Higher values will make the model more
            random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
            sampling from a uniform distribution).
        freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
            tokens. Higher values will make the model more random.
        use_past_kv_cache (bool): If True, create and use cache to speed up generation.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (applicable when input is a string). Defaults to None,
            implying usage of self.cfg.default_prepend_bos (default is True unless specified
            otherwise). Pass True or False to override the default.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.
        return_type (Optional[str]): The type of the output to return - either a string (str),
            a tensor of tokens (tensor) or whatever the format of the input was (input).
        verbose (bool): If True, show tqdm progress bars for generation.

    Yields:
        Int[torch.Tensor, "batch pos"]: Chunks of the token stream. The first chunk also
            contains the prompt tokens; later chunks contain only newly sampled tokens.
            A chunk is flushed whenever max_tokens_per_yield tokens have accumulated since
            the previous yield, and any remainder is flushed when generation stops.

    Returns:
        The complete generated output - a decoded string (return_type "str", batch
        element 0) or the full [batch, pos] token tensor. Note this is the generator's
        *return value*: callers receive it via ``StopIteration.value`` when exhausting
        the generator, or as the value of a ``yield from`` expression.
    """

    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        if type(input) == str:
            # If text, convert to tokens (batch_size=1)
            assert (
                self.tokenizer is not None
            ), "Must provide a tokenizer if passing a string to the model"
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        else:
            tokens = input

        if return_type == "input":
            if type(input) == str:
                return_type = "str"
            else:
                return_type = "tensor"

        assert isinstance(tokens, torch.Tensor)
        batch_size, ctx_length = tokens.shape
        device = devices.get_device_for_block_index(0, self.cfg)
        tokens = tokens.to(device)
        if use_past_kv_cache:
            past_kv_cache = HookedTransformerKeyValueCache.init_cache(
                self.cfg, self.cfg.device, batch_size
            )
        else:
            past_kv_cache = None

        stop_tokens: List[int] = []
        eos_token_for_padding = 0
        assert self.tokenizer is not None
        if stop_at_eos:
            tokenizer_has_eos_token = (
                self.tokenizer is not None and self.tokenizer.eos_token_id is not None
            )
            if eos_token_id is None:
                assert (
                    tokenizer_has_eos_token
                ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"

                eos_token_id = self.tokenizer.eos_token_id

            if isinstance(eos_token_id, int):
                stop_tokens = [eos_token_id]
                eos_token_for_padding = eos_token_id
            else:
                # eos_token_id is a Sequence (e.g. list or tuple)
                stop_tokens = eos_token_id
                eos_token_for_padding = (
                    self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
                )

        # An array to track which sequences in the batch have finished.
        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

        # Tokens accumulated since the last yield; None means "nothing pending".
        # Initialised before the loop so the post-loop flush is well-defined even when
        # max_new_tokens == 0 (previously these names were only bound inside the loop,
        # raising NameError in that case).
        accumulated_tokens: Optional[torch.Tensor] = None
        tokens_since_last_yield = 0

        # Currently nothing in HookedTransformer changes with eval, but this is here in
        # case that changes in the future.
        self.eval()
        for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
            # While generating, we keep generating logits, throw away all but the final
            # logits, and then use those logits to sample from the distribution. We keep
            # adding the sampled tokens to the end of tokens.
            if use_past_kv_cache:
                # We just take the final tokens, as a [batch, 1] tensor
                if index > 0:
                    logits = self.forward(
                        tokens[:, -1:],
                        return_type="logits",
                        prepend_bos=prepend_bos,
                        padding_side=padding_side,
                        past_kv_cache=past_kv_cache,
                    )
                else:
                    logits = self.forward(
                        tokens,
                        return_type="logits",
                        prepend_bos=prepend_bos,
                        padding_side=padding_side,
                        past_kv_cache=past_kv_cache,
                    )
            else:
                # We input the entire sequence, as a [batch, pos] tensor, since we aren't
                # using the cache.
                logits = self.forward(
                    tokens,
                    return_type="logits",
                    prepend_bos=prepend_bos,
                    padding_side=padding_side,
                )
            final_logits = logits[:, -1, :]

            if do_sample:
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    tokens=tokens,
                ).to(devices.get_device_for_block_index(0, self.cfg))
            else:
                sampled_tokens = final_logits.argmax(-1).to(
                    devices.get_device_for_block_index(0, self.cfg)
                )

            if stop_at_eos:
                # For all unfinished sequences, add on the next token. If a sequence was
                # finished, throw away the generated token and add eos_token_for_padding
                # instead.
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(
                    torch.isin(
                        sampled_tokens.to(self.cfg.device),
                        torch.tensor(stop_tokens).to(self.cfg.device),
                    )
                )

            new_tokens = sampled_tokens.unsqueeze(-1)

            # Accumulate tokens until we hit max_tokens_per_yield. The first chunk also
            # carries the prompt tokens (which count towards the yield threshold).
            if index == 0:
                accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1)
                tokens_since_last_yield = accumulated_tokens.shape[1]
            elif accumulated_tokens is None:
                accumulated_tokens = new_tokens
                tokens_since_last_yield += 1
            else:
                accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
                tokens_since_last_yield += 1

            if tokens_since_last_yield >= max_tokens_per_yield:
                yield accumulated_tokens
                tokens_since_last_yield = 0
                accumulated_tokens = None

            tokens = torch.cat([tokens, new_tokens], dim=-1)

            if stop_at_eos and finished_sequences.all():
                break

        # Flush any remaining accumulated tokens. Because accumulated_tokens is reset to
        # None on every yield, this single flush point covers both the early-stop (EOS)
        # and the max_new_tokens-exhausted exits without double-yielding.
        if accumulated_tokens is not None:
            yield accumulated_tokens

        if return_type == "str":
            if self.cfg.default_prepend_bos:
                # If we prepended a BOS token, remove it when returning output.
                return self.tokenizer.decode(tokens[0, 1:])
            else:
                return self.tokenizer.decode(tokens[0])

        else:
            return tokens
2575+
23432576
# Give access to all weights as properties.
23442577
@property
23452578
def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:

0 commit comments

Comments
 (0)