diff --git a/rig/rig-core/src/providers/ollama.rs b/rig/rig-core/src/providers/ollama.rs index 09abd9d85..20bac5e96 100644 --- a/rig/rig-core/src/providers/ollama.rs +++ b/rig/rig-core/src/providers/ollama.rs @@ -6,10 +6,16 @@ //! use rig::completion::Prompt; //! use rig::providers::ollama; //! -//! // Create a new Ollama client (defaults to http://localhost:11434) -//! // In the case of ollama, no API key is necessary, so we use the `Nothing` struct +//! // Create a new Ollama client (defaults to http://localhost:11434, no auth) //! let client = ollama::Client::new(Nothing).unwrap(); //! +//! // Or connect to a remote/proxied Ollama instance with authentication +//! let client = ollama::Client::builder() +//! .api_key("my-secret-key") +//! .base_url("http://remote-ollama:11434") +//! .build() +//! .unwrap(); +//! //! // Create an agent with a preamble //! let comedian_agent = client //! .agent("qwen2.5:14b") @@ -32,7 +38,8 @@ //! let extractor = client.extractor::("llama3.2").build(); //! ``` use crate::client::{ - self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient, + self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, + ProviderClient, }; use crate::completion::{GetTokenUsage, Usage}; use crate::http_client::{self, HttpClientExt}; @@ -58,6 +65,45 @@ use tracing_futures::Instrument; const OLLAMA_API_BASE_URL: &str = "http://localhost:11434"; +/// Optional API key for Ollama. By default Ollama requires no authentication, +/// but proxied or secured deployments may require a Bearer token. +#[derive(Debug, Default, Clone)] +pub struct OllamaApiKey(Option); + +impl ApiKey for OllamaApiKey { + fn into_header( + self, + ) -> Option> { + self.0.map(http_client::make_auth_header) + } +} + +impl From for OllamaApiKey { + fn from(_: Nothing) -> Self { + Self(None) + } +} + +impl From for OllamaApiKey { + fn from(key: String) -> Self { + if key.is_empty() { + Self(None) + } else { + Self(Some(key)) + } + } +} + +impl From<&str> for OllamaApiKey { + fn from(key: &str) -> Self { + if key.is_empty() { + Self(None) + } else { + Self(Some(key.to_owned())) + } + } +} + #[derive(Debug, Default, Clone, Copy)] pub struct OllamaExt; @@ -88,7 +134,7 @@ impl ProviderBuilder for OllamaBuilder { = OllamaExt where H: HttpClientExt; - type ApiKey = Nothing; + type ApiKey = OllamaApiKey; const BASE_URL: &'static str = OLLAMA_API_BASE_URL; @@ -103,23 +149,31 @@ impl ProviderBuilder for OllamaBuilder { } pub type Client = client::Client; -pub type ClientBuilder = client::ClientBuilder; +pub type ClientBuilder = client::ClientBuilder; impl ProviderClient for Client { - type Input = Nothing; + type Input = OllamaApiKey; fn from_env() -> Self { - let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set"); + let api_base = std::env::var("OLLAMA_API_BASE_URL") + .unwrap_or_else(|_| OLLAMA_API_BASE_URL.to_string()); + + let api_key: OllamaApiKey = std::env::var("OLLAMA_API_KEY") + .map(OllamaApiKey::from) + .unwrap_or_default(); Self::builder() - .api_key(Nothing) + .api_key(api_key) .base_url(&api_base) .build() - .unwrap() + .expect("failed to build Ollama client from environment") } - fn from_val(_: Self::Input) -> Self { - Self::builder().api_key(Nothing).build().unwrap() + fn from_val(api_key: Self::Input) -> Self { + Self::builder() + .api_key(api_key) + .build() + .expect("failed to build Ollama client") } }