diff --git a/tests/test_ttok.py b/tests/test_ttok.py index 1db5fd6..995ec1e 100644 --- a/tests/test_ttok.py +++ b/tests/test_ttok.py @@ -61,6 +61,9 @@ def test_ttok_count_and_tokens(args, expected_length, expected_tokens): [b"86127", b"15682", b"48864", b"21990", b"38641", "--decode", "--tokens"], "[b'\\xe7\\xa7\\x81', b'\\xe3\\x81\\xaf', b'\\xe5\\xad\\xa6', b'\\xe7\\x94\\x9f', b'\\xe3\\x81\\xa7\\xe3\\x81\\x99']", ), + (["hello", "big", "world", "out", "there", "--tokens", "--chunksize", "2"], "[b'hello', b' big']\n[b' world', b' out']\n[b' there']"), + (["hello", "big", "world", "out", "there", "--chunksize", "2"], "hello big\n world out\n there"), + (["hello", "big", "world", "out", "there", "--encode", "--chunksize", "2"], "15339 2466\n1917 704\n1070"), ), ) def test_ttok_decode_encode_tokens(args, expected): diff --git a/ttok/cli.py b/ttok/cli.py index dfd8828..f9017c0 100644 --- a/ttok/cli.py +++ b/ttok/cli.py @@ -20,6 +20,7 @@ ) @click.option("as_tokens", "--tokens", is_flag=True, help="Output full tokens") @click.option("--allow-special", is_flag=True, help="Do not error on special tokens") +@click.option("--chunksize", type=int, help="Output chunks of this size") def cli( prompt, input, @@ -29,6 +30,7 @@ def cli( decode_tokens, as_tokens, allow_special, + chunksize, ): """ Count and truncate text based on tokens @@ -113,12 +115,23 @@ def cli( if truncate: tokens = tokens[:truncate] + if chunksize: + chunks = [tokens[i:i + chunksize] for i in range(0, len(tokens), chunksize)] + else: + chunks = [tokens] + if encode_tokens: if as_tokens: - click.echo(encoding.decode_tokens_bytes(tokens)) + for chunk in chunks: + click.echo(encoding.decode_tokens_bytes(chunk)) else: - click.echo(" ".join(str(t) for t in tokens)) + for chunk in chunks: + click.echo(" ".join(str(t) for t in chunk)) + elif chunksize: + for chunk in chunks: + click.echo(encoding.decode(chunk)) elif truncate: click.echo(encoding.decode(tokens), nl=False) else: click.echo(len(tokens)) +