diff --git a/bindings/python/amrs/config.py b/bindings/python/amrs/config.py index 9ad1967..03dbd50 100644 --- a/bindings/python/amrs/config.py +++ b/bindings/python/amrs/config.py @@ -55,7 +55,7 @@ class ModelConfig(BasicModelConfig): ) weight: Optional[int] = Field( default=-1, - description="Weight of the model for ensemble methods. Only used if routing_mode is 'weighted'.", + description="Weight of the model for ensemble methods. Only used if router_mode is 'weighted'.", ) @@ -79,7 +79,7 @@ class RouterMode(str, Enum): class Config(BasicModelConfig): models: List[ModelConfig] = Field(description="List of model configurations") - routing_mode: RouterMode = Field( + router_mode: RouterMode = Field( default=RouterMode.RANDOM, description="Routing mode for the model, default is random.", ) diff --git a/examples/wrr.rs b/examples/wrr.rs index b72cf4f..63d5767 100644 --- a/examples/wrr.rs +++ b/examples/wrr.rs @@ -5,7 +5,7 @@ use tokio::runtime::Runtime; fn main() { let config = client::Config::builder() .provider("deepinfra") - .routing_mode(client::RouterMode::WRR) + .router_mode(client::RouterMode::WRR) .model( client::ModelConfig::builder() .name("deepseek-ai/DeepSeek-V3.2") diff --git a/src/client/client.rs b/src/client/client.rs index a156e0d..f465f1e 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -24,7 +24,7 @@ impl Client { Self { providers: providers, - router: router::construct_router(cfg.routing_mode, cfg.models), + router: router::construct_router(cfg.router_mode, cfg.models), } } @@ -81,7 +81,7 @@ mod tests { TestCase { name: "weighted round-robin router", config: Config::builder() - .routing_mode(RouterMode::WRR) + .router_mode(RouterMode::WRR) .models(vec![ crate::client::config::ModelConfig::builder() .name("model_a".to_string()) diff --git a/src/client/config.rs b/src/client/config.rs index 68329f1..e584f47 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -84,7 +84,7 @@ pub struct Config { pub(crate) provider: String, #[builder(default = "RouterMode::Random")] - pub(crate) routing_mode: RouterMode, + pub(crate) router_mode: RouterMode, #[builder(default = "vec![]")] pub(crate) models: Vec, } @@ -151,8 +151,8 @@ impl ConfigBuilder { } for model in self.models.as_ref().unwrap() { - if self.routing_mode.is_some() - && self.routing_mode.as_ref().unwrap() == &RouterMode::WRR + if self.router_mode.is_some() + && self.router_mode.as_ref().unwrap() == &RouterMode::WRR && model.weight <= 0 { return Err(format!( @@ -218,7 +218,7 @@ mod tests { assert!(valid_simplest_models_cfg.is_ok()); assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == DEFAULT_PROVIDER); assert!(valid_simplest_models_cfg.as_ref().unwrap().base_url == None); - assert!(valid_simplest_models_cfg.as_ref().unwrap().routing_mode == RouterMode::Random); + assert!(valid_simplest_models_cfg.as_ref().unwrap().router_mode == RouterMode::Random); assert!(valid_simplest_models_cfg.as_ref().unwrap().models.len() == 1); assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].base_url == None); assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].provider == None); diff --git a/tests/client.rs b/tests/client.rs index 68fe94f..f30d1b5 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -57,7 +57,7 @@ mod tests { // case 3: multiple models with router. let config = client::Config::builder() .provider("faker") - .routing_mode(client::RouterMode::WRR) + .router_mode(client::RouterMode::WRR) .model( client::ModelConfig::builder() .name("gpt-3.5-turbo")