Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 65 additions & 11 deletions rig/rig-core/src/providers/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@
//! use rig::completion::Prompt;
//! use rig::providers::ollama;
//!
//! // Create a new Ollama client (defaults to http://localhost:11434)
//! // In the case of ollama, no API key is necessary, so we use the `Nothing` struct
//! // Create a new Ollama client (defaults to http://localhost:11434, no auth)
//! let client = ollama::Client::new(Nothing).unwrap();
//!
//! // Or connect to a remote/proxied Ollama instance with authentication
//! let client = ollama::Client::builder()
//! .api_key("my-secret-key")
//! .base_url("http://remote-ollama:11434")
//! .build()
//! .unwrap();
//!
//! // Create an agent with a preamble
//! let comedian_agent = client
//! .agent("qwen2.5:14b")
Expand All @@ -32,7 +38,8 @@
//! let extractor = client.extractor::<serde_json::Value>("llama3.2").build();
//! ```
use crate::client::{
self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
ProviderClient,
};
use crate::completion::{GetTokenUsage, Usage};
use crate::http_client::{self, HttpClientExt};
Expand All @@ -58,6 +65,45 @@ use tracing_futures::Instrument;

const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

/// Optional API key for Ollama. By default Ollama requires no authentication,
/// but proxied or secured deployments may require a Bearer token.
#[derive(Debug, Default, Clone)]
pub struct OllamaApiKey(Option<String>);

impl ApiKey for OllamaApiKey {
fn into_header(
self,
) -> Option<http_client::Result<(http::header::HeaderName, http::header::HeaderValue)>> {
self.0.map(http_client::make_auth_header)
}
}

impl From<Nothing> for OllamaApiKey {
fn from(_: Nothing) -> Self {
Self(None)
}
}

impl From<String> for OllamaApiKey {
fn from(key: String) -> Self {
if key.is_empty() {
Self(None)
} else {
Self(Some(key))
}
}
}

impl From<&str> for OllamaApiKey {
fn from(key: &str) -> Self {
if key.is_empty() {
Self(None)
} else {
Self(Some(key.to_owned()))
}
}
}

#[derive(Debug, Default, Clone, Copy)]
pub struct OllamaExt;

Expand Down Expand Up @@ -88,7 +134,7 @@ impl ProviderBuilder for OllamaBuilder {
= OllamaExt
where
H: HttpClientExt;
type ApiKey = Nothing;
type ApiKey = OllamaApiKey;

const BASE_URL: &'static str = OLLAMA_API_BASE_URL;

Expand All @@ -103,23 +149,31 @@ impl ProviderBuilder for OllamaBuilder {
}

pub type Client<H = reqwest::Client> = client::Client<OllamaExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, Nothing, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<OllamaBuilder, OllamaApiKey, H>;

impl ProviderClient for Client {
type Input = Nothing;
type Input = OllamaApiKey;

fn from_env() -> Self {
let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set");
let api_base = std::env::var("OLLAMA_API_BASE_URL")
.unwrap_or_else(|_| OLLAMA_API_BASE_URL.to_string());

let api_key: OllamaApiKey = std::env::var("OLLAMA_API_KEY")
.map(OllamaApiKey::from)
.unwrap_or_default();

Self::builder()
.api_key(Nothing)
.api_key(api_key)
.base_url(&api_base)
.build()
.unwrap()
.expect("failed to build Ollama client from environment")
}

fn from_val(_: Self::Input) -> Self {
Self::builder().api_key(Nothing).build().unwrap()
fn from_val(api_key: Self::Input) -> Self {
Self::builder()
.api_key(api_key)
.build()
.expect("failed to build Ollama client")
}
}

Expand Down