From 7b5b981de71709eb682a35d81c03ca2d5642d0fa Mon Sep 17 00:00:00 2001 From: dongxu Date: Thu, 18 May 2023 15:27:19 -0700 Subject: [PATCH] Add split option to token counting CLI This commit introduces a new option (-s, --split) to the command-line interface of the token counting script. The new option allows users to split their input text into multiple files, each containing a specific number of tokens. When the split option is used in conjunction with the truncate option (-t, --truncate), the script splits the tokens into groups of the specified size and writes each group into a separate file. The output files are named using the provided filename and are appended with '.part' followed by the chunk index. A new helper function 'grouper' has been added to divide the iterable (tokens) into fixed-size chunks. The split function ensures that each chunk is decoded before being written to its respective output file. --- ttok/cli.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/ttok/cli.py b/ttok/cli.py index 73e1652..72e2a06 100644 --- a/ttok/cli.py +++ b/ttok/cli.py @@ -1,6 +1,7 @@ import click import sys import tiktoken +from itertools import zip_longest @click.command() @@ -10,9 +11,10 @@ @click.option( "-t", "--truncate", "truncate", type=int, help="Truncate to this many tokens" ) +@click.option("-s", "--split", "split_file", type=click.Path(), help="Split input into multiple files") @click.option("-m", "--model", default="gpt-3.5-turbo", help="Which model to use") @click.option("output_tokens", "--tokens", is_flag=True, help="Output token integers") -def cli(prompt, input, truncate, model, output_tokens): +def cli(prompt, input, truncate, split_file, model, output_tokens): """ Count and truncate text based on tokens @@ -51,6 +53,21 @@ def cli(prompt, input, truncate, model, output_tokens): text = input_text # Tokenize it tokens = encoding.encode(text) + + if split_file and truncate: + # Ensure the truncate value is valid + if truncate <= 0: + raise click.ClickException(f"Invalid truncate value: {truncate}") + # Split the tokens into groups of the specified size + chunks = grouper(tokens, truncate) + # Write each chunk to a separate file + for i, chunk in enumerate(chunks): + with open(f"{split_file}.part{i}", "w") as f: + chunk = [token for token in chunk if token is not None] + decoded_chunk = encoding.decode(chunk) + f.write(decoded_chunk) + return # Exit the function early, we don't need to do anything else in this case + if truncate: tokens = tokens[:truncate] @@ -60,3 +77,9 @@ def cli(prompt, input, truncate, model, output_tokens): click.echo(encoding.decode(tokens), nl=False) else: click.echo(len(tokens)) + +def grouper(iterable, n, fillvalue=None): + "Collect data into fixed-length chunks or blocks" + # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx" + args = [iter(iterable)] * n + return zip_longest(fillvalue=fillvalue, *args)