Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion ttok/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import click
import sys
import tiktoken
from itertools import zip_longest


@click.command()
Expand All @@ -10,9 +11,10 @@
@click.option(
"-t", "--truncate", "truncate", type=int, help="Truncate to this many tokens"
)
@click.option("-s", "--split", "split_file", type=click.Path(), help="Split input into multiple files")
@click.option("-m", "--model", default="gpt-3.5-turbo", help="Which model to use")
@click.option("output_tokens", "--tokens", is_flag=True, help="Output token integers")
def cli(prompt, input, truncate, model, output_tokens):
def cli(prompt, input, truncate, split_file, model, output_tokens):
"""
Count and truncate text based on tokens

Expand Down Expand Up @@ -51,6 +53,21 @@ def cli(prompt, input, truncate, model, output_tokens):
text = input_text
# Tokenize it
tokens = encoding.encode(text)

if split_file and truncate:
# Ensure the truncate value is valid
if truncate <= 0:
raise click.ClickException(f"Invalid truncate value: {truncate}")
# Split the tokens into groups of the specified size
chunks = grouper(tokens, truncate)
# Write each chunk to a separate file
for i, chunk in enumerate(chunks):
with open(f"{split_file}.part{i}", "w") as f:
chunk = [token for token in chunk if token is not None]
decoded_chunk = encoding.decode(chunk)
f.write(decoded_chunk)
return # Exit the function early, we don't need to do anything else in this case

if truncate:
tokens = tokens[:truncate]

Expand All @@ -60,3 +77,9 @@ def cli(prompt, input, truncate, model, output_tokens):
click.echo(encoding.decode(tokens), nl=False)
else:
click.echo(len(tokens))

def grouper(iterable, n, fillvalue=None):
"Collect data into fixed-length chunks or blocks"
# grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
args = [iter(iterable)] * n
return zip_longest(fillvalue=fillvalue, *args)