diff --git a/athena/utils/llm_util.py b/athena/utils/llm_util.py
index f7c40f5..db02477 100644
--- a/athena/utils/llm_util.py
+++ b/athena/utils/llm_util.py
@@ -6,14 +6,39 @@ def str_token_counter(text: str) -> int:
+    """Counts the number of tokens in a string using tiktoken's o200k_base encoding.
+
+    Args:
+        text: The input string to count tokens for.
+
+    Returns:
+        The number of tokens in the input string.
+    """
     enc = tiktoken.get_encoding("o200k_base")
     return len(enc.encode(text))
 
 
 def tiktoken_counter(messages: Sequence[BaseMessage]) -> int:
-    """Approximately reproduce https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    """Counts tokens across multiple message types using tiktoken tokenization.
+
+    Approximately reproduces the token counting methodology from OpenAI's cookbook:
+    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+
+    Args:
+        messages: A sequence of BaseMessage objects (HumanMessage, AIMessage,
+            ToolMessage, or SystemMessage) to count tokens for.
+
+    Returns:
+        The total number of tokens across all messages, including overhead tokens.
+
+    Raises:
+        ValueError: If an unsupported message type is encountered.
 
-    For simplicity only supports str Message.contents.
+    Notes:
+        - Uses a fixed overhead of 3 tokens for reply priming
+        - Adds 3 tokens per message for message formatting
+        - Adds 1 token per message name if present
+        - For simplicity, only supports string message contents
     """
     output_parser = StrOutputParser()
     num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>