Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions config/plano_config_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ properties:
agent_orchestration_model:
type: string
description: "Model name for the agent orchestrator (e.g., 'Plano-Orchestrator'). Must match a model in model_providers."
token_counting_strategy:
type: string
enum: [estimate, auto]
description: >
Strategy for counting input tokens used in metrics and rate limiting.
"estimate" (default): fast character-based approximation (~1 token per 4 chars).
"auto": uses the best available tokenizer for each provider (e.g., tiktoken for
OpenAI models), falling back to estimate for unsupported providers.
system_prompt:
type: string
prompt_targets:
Expand Down
10 changes: 10 additions & 0 deletions crates/common/src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,23 @@ pub struct Configuration {
pub state_storage: Option<StateStorageConfig>,
}

/// How input tokens are counted for metrics and rate limiting.
///
/// Mirrors the `token_counting_strategy` key in `plano_config_schema.yaml`:
/// serialized values are the lowercase variant names (`"estimate"`, `"auto"`).
//
// Fieldless enum: `Copy` + `Eq` + `Hash` are free and let callers avoid the
// `.clone()` currently needed when reading this out of `Overrides`.
// `rename_all = "lowercase"` replaces the two per-variant renames with the
// exact same wire format.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum TokenCountingStrategy {
    /// Fast character-based approximation (~1 token per 4 chars). The default.
    #[default]
    Estimate,
    /// Best available tokenizer per provider (e.g. tiktoken for OpenAI),
    /// falling back to the estimate for unsupported providers.
    Auto,
}

/// Optional per-deployment tuning knobs; every field is `None` unless set in
/// config, and `Default` yields all-`None` (i.e. "no overrides").
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Overrides {
    // Threshold for intent matching against prompt targets.
    // NOTE(review): assumed to be a similarity score in [0.0, 1.0] — confirm at use site.
    pub prompt_target_intent_matching_threshold: Option<f64>,
    // Toggle for context-window optimization; semantics live at the consumer.
    pub optimize_context_window: Option<bool>,
    // Whether requests should be routed through the agent orchestrator.
    pub use_agent_orchestrator: Option<bool>,
    // Model name used for LLM routing decisions; presumably must match an
    // entry in `model_providers` — verify against config validation.
    pub llm_routing_model: Option<String>,
    // Model name for the agent orchestrator (schema: must match a model in
    // `model_providers`).
    pub agent_orchestration_model: Option<String>,
    // Strategy for counting input tokens (metrics + rate limiting).
    // Absent means `TokenCountingStrategy::Estimate` (its `Default`).
    pub token_counting_strategy: Option<TokenCountingStrategy>,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
Expand Down
35 changes: 26 additions & 9 deletions crates/llm_gateway/src/stream_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use crate::metrics::Metrics;
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
use common::configuration::{LlmProvider, LlmProviderType, Overrides, TokenCountingStrategy};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
Expand Down Expand Up @@ -48,7 +48,7 @@ pub struct StreamContext {
ttft_time: Option<u128>,
traceparent: Option<String>,
request_body_sent_time: Option<u128>,
_overrides: Rc<Option<Overrides>>,
overrides: Rc<Option<Overrides>>,
user_message: Option<String>,
upstream_status_code: Option<StatusCode>,
binary_frame_decoder: Option<BedrockBinaryFrameDecoder<bytes::BytesMut>>,
Expand All @@ -66,7 +66,7 @@ impl StreamContext {
) -> Self {
StreamContext {
metrics,
_overrides: overrides,
overrides,
ratelimit_selector: None,
streaming_response: false,
response_tokens: 0,
Expand Down Expand Up @@ -269,22 +269,39 @@ impl StreamContext {
model: &str,
json_string: &str,
) -> Result<(), ratelimit::Error> {
// Tokenize and record token count.
let token_count = tokenizer::token_count(model, json_string).unwrap_or(0);
let strategy = (*self.overrides)
.as_ref()
.and_then(|o| o.token_counting_strategy.clone())
.unwrap_or_default();

let (token_count, method) = match strategy {
TokenCountingStrategy::Auto => {
let provider_id = self.get_provider_id();
match provider_id {
ProviderId::OpenAI => (
tokenizer::token_count(model, json_string).unwrap_or(json_string.len() / 4),
"tiktoken",
),
// Future: add provider-specific tokenizers here
// ProviderId::Mistral => (mistral_tokenizer::count(...), "mistral"),
_ => (json_string.len() / 4, "estimate"),
}
}
TokenCountingStrategy::Estimate => (json_string.len() / 4, "estimate"),
};

debug!(
"request_id={}: token count, model='{}' input_tokens={}",
"request_id={}: token count, model='{}' input_tokens={} method={}",
self.request_identifier(),
model,
token_count
token_count,
method
);

// Record the token count to metrics.
self.metrics
.input_sequence_length
.record(token_count as u64);

// Check if rate limiting needs to be applied.
if let Some(selector) = self.ratelimit_selector.take() {
info!(
"request_id={}: ratelimit check, model='{}' selector='{}:{}'",
Expand Down
Loading