Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions athena/utils/llm_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,39 @@


def str_token_counter(text: str) -> int:
    """Return the number of o200k_base tokens in *text*.

    Tokenization is performed with tiktoken's ``o200k_base`` encoding
    (the encoding used by recent OpenAI models).

    Args:
        text: String whose token count is wanted.

    Returns:
        Count of tokens produced by encoding ``text``.
    """
    encoding = tiktoken.get_encoding("o200k_base")
    tokens = encoding.encode(text)
    return len(tokens)


def tiktoken_counter(messages: Sequence[BaseMessage]) -> int:
"""Approximately reproduce https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
"""Counts tokens across multiple message types using tiktoken tokenization.

Approximately reproduces the token counting methodology from OpenAI's cookbook:
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

Args:
messages: A sequence of BaseMessage objects (HumanMessage, AIMessage,
ToolMessage, or SystemMessage) to count tokens for.

Returns:
The total number of tokens across all messages, including overhead tokens.

Raises:
ValueError: If an unsupported message type is encountered.

For simplicity only supports str Message.contents.
Notes:
- Uses a fixed overhead of 3 tokens for reply priming
- Adds 3 tokens per message for message formatting
- Adds 1 token per message name if present
- For simplicity, only supports string message contents
"""
output_parser = StrOutputParser()
num_tokens = 3 # every reply is primed with <|start|>assistant<|message|>
Expand Down