Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions rig/rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ reqwest-middleware = { version = "0.5.1", optional = true, features = [
"http2",
] }

# Native-only dependency for OpenAI Responses websocket mode.
# Keeping it target-scoped avoids pulling a tokio socket transport into wasm builds.
[target.'cfg(not(target_family = "wasm"))'.dependencies]
tokio-tungstenite = { version = "0.23.1", features = ["rustls-tls-webpki-roots"] }

[dev-dependencies]
anyhow = { workspace = true }
assert_fs = { workspace = true }
Expand Down
68 changes: 68 additions & 0 deletions rig/rig-core/examples/openai_websocket_mode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use anyhow::Result;
use rig::client::{CompletionClient, ProviderClient};
use rig::completion::CompletionModel;
use rig::providers::openai;
use rig::providers::openai::responses_api::streaming::{ItemChunkKind, ResponseChunkKind};
use rig::providers::openai::responses_api::websocket::ResponsesWebSocketEvent;

#[tokio::main]
async fn main() -> Result<()> {
    // Surface connection/session diagnostics emitted via `tracing`.
    tracing_subscriber::fmt().init();

    let client = openai::Client::from_env();
    let model_name = openai::GPT_4O_MINI;
    let model = client.completion_model(model_name);
    let mut session = client.responses_websocket(model_name).await?;

    // Prime the session so the follow-up question has server-side context.
    let warmup_id = session
        .warmup(
            model
                .completion_request("You will answer a follow-up question about websocket mode.")
                .preamble("Be precise and concise.".to_string())
                .build(),
        )
        .await?;
    println!("Warmup response id: {warmup_id}");

    // Kick off a streaming completion over the already-open socket.
    session
        .send(
            model
                .completion_request("Explain the benefit of websocket mode in one sentence.")
                .build(),
        )
        .await?;

    // Drain streamed events until the response reaches a terminal state.
    'stream: loop {
        match session.next_event().await? {
            ResponsesWebSocketEvent::Item(item) => {
                // Only text deltas are printed; other item events are ignored.
                if let ItemChunkKind::OutputTextDelta(delta) = item.data {
                    print!("{}", delta.delta);
                }
            }
            ResponsesWebSocketEvent::Response(chunk) => {
                println!("\nresponse event: {:?}", chunk.kind);
                let terminal = matches!(
                    chunk.kind,
                    ResponseChunkKind::ResponseCompleted
                        | ResponseChunkKind::ResponseFailed
                        | ResponseChunkKind::ResponseIncomplete
                );
                if terminal {
                    break 'stream;
                }
            }
            ResponsesWebSocketEvent::Done(done) => {
                println!("\nresponse.done id={:?}", done.response_id());
            }
            ResponsesWebSocketEvent::Error(error) => {
                return Err(anyhow::anyhow!(error.to_string()));
            }
        }
    }

    // Reuse the same socket for a non-streaming, chained completion.
    let chained_request = model
        .completion_request("Now restate that as three very short bullet points.")
        .build();
    let response = session.completion(chained_request).await?;
    println!("Chained response: {:?}", response.choice);

    session.close().await?;
    Ok(())
}
26 changes: 26 additions & 0 deletions rig/rig-core/src/providers/openai/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,32 @@ where
}
}

#[cfg(not(target_family = "wasm"))]
impl Client<reqwest::Client> {
    /// Creates a builder for an OpenAI Responses websocket session on `model`.
    ///
    /// WebSocket mode currently uses a native `tokio-tungstenite` transport and does
    /// not reuse custom `HttpClientExt` backends, so this API is only exposed for the
    /// default `reqwest::Client` transport.
    pub fn responses_websocket_builder(
        &self,
        model: impl Into<String>,
    ) -> super::responses_api::websocket::ResponsesWebSocketSessionBuilder {
        use super::responses_api::websocket::ResponsesWebSocketSessionBuilder as Builder;
        let completion_model = self.completion_model(model);
        Builder::new(completion_model)
    }

    /// Connects a websocket session for `model` with default builder settings.
    ///
    /// This API is OpenAI-specific and only available on non-wasm targets in `rig-core`.
    pub async fn responses_websocket(
        &self,
        model: impl Into<String>,
    ) -> Result<
        super::responses_api::websocket::ResponsesWebSocketSession,
        crate::completion::CompletionError,
    > {
        let builder = self.responses_websocket_builder(model);
        builder.connect().await
    }
}

impl<H> CompletionsClient<H>
where
H: HttpClientExt
Expand Down
2 changes: 2 additions & 0 deletions rig/rig-core/src/providers/openai/responses_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ use std::ops::Add;
use std::str::FromStr;

pub mod streaming;
#[cfg(not(target_family = "wasm"))]
pub mod websocket;

/// The completion request type for OpenAI's Response API: <https://platform.openai.com/docs/api-reference/responses/create>
/// Intended to be derived from [`crate::completion::request::CompletionRequest`].
Expand Down
108 changes: 104 additions & 4 deletions rig/rig-core/src/providers/openai/responses_api/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ pub struct ContentPartChunk {
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartChunkPart {
OutputText { text: String },
SummaryText { text: String },
Expand Down Expand Up @@ -227,7 +227,7 @@ pub struct SummaryTextChunk {
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SummaryPartChunkPart {
SummaryText { text: String },
}
Expand Down Expand Up @@ -430,13 +430,13 @@ where

#[cfg(test)]
mod tests {
use super::reasoning_choices_from_done_item;
use super::{ItemChunkKind, StreamingCompletionChunk, reasoning_choices_from_done_item};
use crate::message::ReasoningContent;
use crate::providers::openai::responses_api::ReasoningSummary;
use crate::streaming::RawStreamingChoice;
use futures::StreamExt;
use rig::{client::CompletionClient, providers::openai, streaming::StreamingChat};
use serde_json;
use serde_json::{self, json};

use crate::{
completion::ToolDefinition,
Expand Down Expand Up @@ -522,6 +522,106 @@ mod tests {
));
}

#[test]
fn content_part_added_deserializes_snake_case_part_type() {
    // The wire format tags content parts with snake_case `type` values.
    let payload = json!({
        "type": "response.content_part.added",
        "item_id": "msg_1",
        "output_index": 0,
        "content_index": 0,
        "sequence_number": 3,
        "part": {
            "type": "output_text",
            "text": "hello"
        }
    });

    let chunk: StreamingCompletionChunk =
        serde_json::from_value(payload).expect("content part event should deserialize");

    let is_content_part_added = matches!(
        chunk,
        StreamingCompletionChunk::Delta(delta)
            if matches!(delta.data, ItemChunkKind::ContentPartAdded(_))
    );
    assert!(is_content_part_added);
}

#[test]
fn content_part_done_deserializes_snake_case_part_type() {
    // `summary_text` parts must also round-trip through the snake_case tag.
    let payload = json!({
        "type": "response.content_part.done",
        "item_id": "msg_1",
        "output_index": 0,
        "content_index": 0,
        "sequence_number": 4,
        "part": {
            "type": "summary_text",
            "text": "done"
        }
    });

    let chunk: StreamingCompletionChunk =
        serde_json::from_value(payload).expect("content part done event should deserialize");

    let is_content_part_done = matches!(
        chunk,
        StreamingCompletionChunk::Delta(delta)
            if matches!(delta.data, ItemChunkKind::ContentPartDone(_))
    );
    assert!(is_content_part_done);
}

#[test]
fn reasoning_summary_part_added_deserializes_snake_case_part_type() {
    // Reasoning summary parts use the same snake_case `type` tagging scheme.
    let payload = json!({
        "type": "response.reasoning_summary_part.added",
        "item_id": "rs_1",
        "output_index": 0,
        "summary_index": 0,
        "sequence_number": 5,
        "part": {
            "type": "summary_text",
            "text": "step 1"
        }
    });

    let chunk: StreamingCompletionChunk =
        serde_json::from_value(payload).expect("reasoning summary part event should deserialize");

    let is_summary_part_added = matches!(
        chunk,
        StreamingCompletionChunk::Delta(delta)
            if matches!(delta.data, ItemChunkKind::ReasoningSummaryPartAdded(_))
    );
    assert!(is_summary_part_added);
}

#[test]
fn reasoning_summary_part_done_deserializes_snake_case_part_type() {
    // Terminal summary-part events carry the same snake_case part payload.
    let payload = json!({
        "type": "response.reasoning_summary_part.done",
        "item_id": "rs_1",
        "output_index": 0,
        "summary_index": 0,
        "sequence_number": 6,
        "part": {
            "type": "summary_text",
            "text": "step 2"
        }
    });

    let chunk: StreamingCompletionChunk = serde_json::from_value(payload)
        .expect("reasoning summary part done event should deserialize");

    let is_summary_part_done = matches!(
        chunk,
        StreamingCompletionChunk::Delta(delta)
            if matches!(delta.data, ItemChunkKind::ReasoningSummaryPartDone(_))
    );
    assert!(is_summary_part_done);
}

// requires `derive` rig-core feature due to using tool macro
#[tokio::test]
#[ignore = "requires API key"]
Expand Down
Loading
Loading