|
11 | 11 |
|
12 | 12 | import logging |
13 | 13 | import os |
| 14 | +from collections.abc import Generator |
14 | 15 | from typing import ( |
15 | 16 | Dict, |
16 | 17 | List, |
@@ -2340,6 +2341,238 @@ def generate( |
2340 | 2341 | else: |
2341 | 2342 | return embeds |
2342 | 2343 |
|
| 2344 | + @torch.inference_mode() |
| 2345 | + def generate_stream( |
| 2346 | + self, |
| 2347 | + input: Union[str, Float[torch.Tensor, "batch pos"]] = "", |
| 2348 | + max_new_tokens: int = 10, |
| 2349 | + max_tokens_per_yield: int = 25, |
| 2350 | + stop_at_eos: bool = True, |
| 2351 | + eos_token_id: Optional[int] = None, |
| 2352 | + do_sample: bool = True, |
| 2353 | + top_k: Optional[int] = None, |
| 2354 | + top_p: Optional[float] = None, |
| 2355 | + temperature: float = 1.0, |
| 2356 | + freq_penalty: float = 0.0, |
| 2357 | + use_past_kv_cache: bool = True, |
| 2358 | + prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, |
| 2359 | + padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, |
| 2360 | + return_type: Optional[str] = "input", |
| 2361 | + verbose: bool = True, |
| 2362 | + ) -> Generator[Union[Int[torch.Tensor, "batch"], str], None, None]: |
| 2363 | + """Stream tokens from the Model as they are generated. |
| 2364 | +
|
| 2365 | + Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached, |
| 2366 | + yielding batches of tokens progressively during generation rather than waiting for the entire |
| 2367 | + sequence to be generated. |
| 2368 | +
|
| 2369 | + To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish |
| 2370 | + (by producing an EOT token), we keep running the model on the entire batch, but throw away |
| 2371 | + the output for a finished sequence and just keep adding EOTs to pad. |
| 2372 | +
|
| 2373 | + This supports entering a single string, but not a list of strings - if the strings don't |
| 2374 | + tokenize to exactly the same length, this gets messy. If that functionality is needed, |
| 2375 | + convert them to a batch of tokens and input that instead. |
| 2376 | +
|
| 2377 | + Args: |
| 2378 | + input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch, |
| 2379 | + pos]) or a text string (this will be converted to a batch of tokens with batch size |
| 2380 | + 1). |
| 2381 | + max_new_tokens (int): Maximum number of tokens to generate. |
| 2382 | + max_tokens_per_yield (int): Maximum number of tokens to accumulate before yielding. |
| 2383 | + Controls how frequently the function yields tokens during generation. |
| 2384 | + stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. |
| 2385 | + eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end |
| 2386 | + of sentence. If None, use the tokenizer's eos_token_id - required if using |
| 2387 | + stop_at_eos. It's also possible to provide a list of token IDs (not just the |
| 2388 | + eos_token_id), in which case the generation will stop when any of them are output |
| 2389 | + (useful e.g. for stable_lm). |
| 2390 | + do_sample (bool): If True, sample from the model's output distribution. Otherwise, use |
| 2391 | + greedy search (take the max logit each time). |
| 2392 | + top_k (int): Number of tokens to sample from. If None, sample from all tokens. |
| 2393 | + top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, |
| 2394 | + we take the top tokens with cumulative probability >= top_p. |
| 2395 | + temperature (float): Temperature for sampling. Higher values will make the model more |
| 2396 | + random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is |
| 2397 | + sampling from a uniform distribution). |
| 2398 | + freq_penalty (float): Frequency penalty for sampling - how much to penalise previous |
| 2399 | + tokens. Higher values will make the model more random. |
| 2400 | + use_past_kv_cache (bool): If True, create and use cache to speed up generation. |
| 2401 | + prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend |
| 2402 | + the BOS token to the input (applicable when input is a string). Defaults to None, |
| 2403 | + implying usage of self.cfg.default_prepend_bos (default is True unless specified |
| 2404 | + otherwise). Pass True or False to override the default. |
| 2405 | + padding_side (Union[Literal["left", "right"], None], optional): Overrides |
| 2406 | + self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple |
| 2407 | + strings of different lengths. |
| 2408 | + return_type (Optional[str]): The type of the output to return - either a string (str), |
| 2409 | + a tensor of tokens (tensor) or whatever the format of the input was (input). |
| 2410 | + verbose (bool): If True, show tqdm progress bars for generation. |
| 2411 | +
|
| 2412 | + Yields: |
| 2413 | + outputs (Union[Int[torch.Tensor, "batch"], str]): Batches of generated tokens, yielded |
| 2414 | + progressively during generation. Each yield contains accumulated tokens since the last |
| 2415 | + yield, up to max_tokens_per_yield. |
| 2416 | + """ |
| 2417 | + |
| 2418 | + with utils.LocallyOverridenDefaults( |
| 2419 | + self, prepend_bos=prepend_bos, padding_side=padding_side |
| 2420 | + ): |
| 2421 | + if type(input) == str: |
| 2422 | + # If text, convert to tokens (batch_size=1) |
| 2423 | + assert ( |
| 2424 | + self.tokenizer is not None |
| 2425 | + ), "Must provide a tokenizer if passing a string to the model" |
| 2426 | + tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) |
| 2427 | + else: |
| 2428 | + tokens = input |
| 2429 | + |
| 2430 | + if return_type == "input": |
| 2431 | + if type(input) == str: |
| 2432 | + return_type = "str" |
| 2433 | + else: |
| 2434 | + return_type = "tensor" |
| 2435 | + |
| 2436 | + assert isinstance(tokens, torch.Tensor) |
| 2437 | + batch_size, ctx_length = tokens.shape |
| 2438 | + device = devices.get_device_for_block_index(0, self.cfg) |
| 2439 | + tokens = tokens.to(device) |
| 2440 | + if use_past_kv_cache: |
| 2441 | + past_kv_cache = HookedTransformerKeyValueCache.init_cache( |
| 2442 | + self.cfg, self.cfg.device, batch_size |
| 2443 | + ) |
| 2444 | + else: |
| 2445 | + past_kv_cache = None |
| 2446 | + |
| 2447 | + stop_tokens: List[int] = [] |
| 2448 | + eos_token_for_padding = 0 |
| 2449 | + assert self.tokenizer is not None |
| 2450 | + if stop_at_eos: |
| 2451 | + tokenizer_has_eos_token = ( |
| 2452 | + self.tokenizer is not None and self.tokenizer.eos_token_id is not None |
| 2453 | + ) |
| 2454 | + if eos_token_id is None: |
| 2455 | + assert ( |
| 2456 | + tokenizer_has_eos_token |
| 2457 | + ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" |
| 2458 | + |
| 2459 | + eos_token_id = self.tokenizer.eos_token_id |
| 2460 | + |
| 2461 | + if isinstance(eos_token_id, int): |
| 2462 | + stop_tokens = [eos_token_id] |
| 2463 | + eos_token_for_padding = eos_token_id |
| 2464 | + else: |
| 2465 | + # eos_token_id is a Sequence (e.g. list or tuple) |
| 2466 | + stop_tokens = eos_token_id |
| 2467 | + eos_token_for_padding = ( |
| 2468 | + self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] |
| 2469 | + ) |
| 2470 | + |
| 2471 | + # An array to track which sequences in the batch have finished. |
| 2472 | + finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) |
| 2473 | + |
| 2474 | + # Currently nothing in HookedTransformer changes with eval, but this is here in case |
| 2475 | + # that changes in the future. |
| 2476 | + self.eval() |
| 2477 | + for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): |
| 2478 | + # While generating, we keep generating logits, throw away all but the final logits, |
| 2479 | + # and then use those logits to sample from the distribution We keep adding the |
| 2480 | + # sampled tokens to the end of tokens. |
| 2481 | + if use_past_kv_cache: |
| 2482 | + # We just take the final tokens, as a [batch, 1] tensor |
| 2483 | + if index > 0: |
| 2484 | + logits = self.forward( |
| 2485 | + tokens[:, -1:], |
| 2486 | + return_type="logits", |
| 2487 | + prepend_bos=prepend_bos, |
| 2488 | + padding_side=padding_side, |
| 2489 | + past_kv_cache=past_kv_cache, |
| 2490 | + ) |
| 2491 | + else: |
| 2492 | + logits = self.forward( |
| 2493 | + tokens, |
| 2494 | + return_type="logits", |
| 2495 | + prepend_bos=prepend_bos, |
| 2496 | + padding_side=padding_side, |
| 2497 | + past_kv_cache=past_kv_cache, |
| 2498 | + ) |
| 2499 | + else: |
| 2500 | + # We input the entire sequence, as a [batch, pos] tensor, since we aren't using |
| 2501 | + # the cache. |
| 2502 | + logits = self.forward( |
| 2503 | + tokens, |
| 2504 | + return_type="logits", |
| 2505 | + prepend_bos=prepend_bos, |
| 2506 | + padding_side=padding_side, |
| 2507 | + ) |
| 2508 | + final_logits = logits[:, -1, :] |
| 2509 | + |
| 2510 | + if do_sample: |
| 2511 | + sampled_tokens = utils.sample_logits( |
| 2512 | + final_logits, |
| 2513 | + top_k=top_k, |
| 2514 | + top_p=top_p, |
| 2515 | + temperature=temperature, |
| 2516 | + freq_penalty=freq_penalty, |
| 2517 | + tokens=tokens, |
| 2518 | + ).to(devices.get_device_for_block_index(0, self.cfg)) |
| 2519 | + else: |
| 2520 | + sampled_tokens = final_logits.argmax(-1).to( |
| 2521 | + devices.get_device_for_block_index(0, self.cfg) |
| 2522 | + ) |
| 2523 | + |
| 2524 | + if stop_at_eos: |
| 2525 | + # For all unfinished sequences, add on the next token. If a sequence was |
| 2526 | + # finished, throw away the generated token and add eos_token_for_padding |
| 2527 | + # instead. |
| 2528 | + sampled_tokens[finished_sequences] = eos_token_for_padding |
| 2529 | + finished_sequences.logical_or_( |
| 2530 | + torch.isin( |
| 2531 | + sampled_tokens.to(self.cfg.device), |
| 2532 | + torch.tensor(stop_tokens).to(self.cfg.device), |
| 2533 | + ) |
| 2534 | + ) |
| 2535 | + |
| 2536 | + new_tokens = sampled_tokens.unsqueeze(-1) |
| 2537 | + |
| 2538 | + # Accumulate tokens until we hit max_tokens_per_yield |
| 2539 | + if index == 0: |
| 2540 | + accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1) |
| 2541 | + tokens_since_last_yield = accumulated_tokens.shape[1] |
| 2542 | + else: |
| 2543 | + if accumulated_tokens is None: |
| 2544 | + accumulated_tokens = new_tokens |
| 2545 | + else: |
| 2546 | + accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1) |
| 2547 | + tokens_since_last_yield += 1 |
| 2548 | + |
| 2549 | + if tokens_since_last_yield >= max_tokens_per_yield: |
| 2550 | + yield accumulated_tokens |
| 2551 | + tokens_since_last_yield = 0 |
| 2552 | + accumulated_tokens = None |
| 2553 | + |
| 2554 | + tokens = torch.cat([tokens, new_tokens], dim=-1) |
| 2555 | + |
| 2556 | + if stop_at_eos and finished_sequences.all(): |
| 2557 | + # Yield any remaining accumulated tokens before breaking |
| 2558 | + if accumulated_tokens is not None: |
| 2559 | + yield accumulated_tokens |
| 2560 | + break |
| 2561 | + |
| 2562 | + # Only yield remaining tokens if we didn't already yield them in the break case |
| 2563 | + if accumulated_tokens is not None and not (stop_at_eos and finished_sequences.all()): |
| 2564 | + yield accumulated_tokens |
| 2565 | + |
| 2566 | + if return_type == "str": |
| 2567 | + if self.cfg.default_prepend_bos: |
| 2568 | + # If we prepended a BOS token, remove it when returning output. |
| 2569 | + return self.tokenizer.decode(tokens[0, 1:]) |
| 2570 | + else: |
| 2571 | + return self.tokenizer.decode(tokens[0]) |
| 2572 | + |
| 2573 | + else: |
| 2574 | + return tokens |
| 2575 | + |
2343 | 2576 | # Give access to all weights as properties. |
2344 | 2577 | @property |
2345 | 2578 | def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: |
|
0 commit comments