diff --git a/README.md b/README.md index 16d4c52..5f0064b 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ The DeepSeek API SDK supports both asynchronous and synchronous usage patterns i ```rust use anyhow::Result; use clap::Parser; +use deepseek_api::request::MessageRequest; use deepseek_api::response::ModelType; use deepseek_api::{CompletionsRequestBuilder, DeepSeekClientBuilder, RequestBuilder}; use std::io::{stdin, stdout, Write}; @@ -83,9 +84,8 @@ async fn main() -> Result<()> { println!("models {:?}", models); } word => { - let resp = CompletionsRequestBuilder::new(vec![]) + let resp = CompletionsRequestBuilder::new(&[MessageRequest::user(word)]) .use_model(ModelType::DeepSeekChat) - .append_user_message(word) .do_request(&client) .await? .must_response(); @@ -134,9 +134,8 @@ fn main() -> Result<()> { .build()?; let mut history = vec![]; - let resp = CompletionsRequestBuilder::new(vec![]) + let resp = CompletionsRequestBuilder::new(&[MessageRequest::user("hello world")]) .use_model(ModelType::DeepSeekReasoner) - .append_user_message("hello world") .do_request(&client)? .must_response(); @@ -163,9 +162,7 @@ Use the function calling interface to define and invoke tools via the API. use anyhow::Result; use clap::Parser; use deepseek_api::request::MessageRequest; -use deepseek_api::request::{ - Function, ToolMessageRequest, ToolObject, ToolType, UserMessageRequest, -}; +use deepseek_api::request::{Function, ToolMessageRequest, ToolObject, ToolType}; use deepseek_api::response::FinishReason; use deepseek_api::{CompletionsRequestBuilder, DeepSeekClientBuilder, RequestBuilder}; use schemars::schema::SchemaObject; @@ -215,11 +212,10 @@ async fn main() -> Result<()> { }, }; - let mut messages = vec![MessageRequest::User(UserMessageRequest::new( - "How's the weather in Hangzhou?", - ))]; - let resp = CompletionsRequestBuilder::new(messages.clone()) - .tools(vec![tool_object.clone()]) + let tool_objects: Vec = vec![tool_object]; + let mut messages = vec![MessageRequest::user("How's the weather in Hangzhou?")]; + let resp = CompletionsRequestBuilder::new(&messages) + .tools(&tool_objects) .do_request(&client) .await? .must_response(); @@ -237,8 +233,8 @@ async fn main() -> Result<()> { } messages.push(MessageRequest::Tool(ToolMessageRequest::new("24℃", &id))); - let resp = CompletionsRequestBuilder::new(messages.clone()) - .tools(vec![tool_object.clone()]) + let resp = CompletionsRequestBuilder::new(&messages) + .tools(&tool_objects) .do_request(&client) .await? .must_response(); diff --git a/deepseek-api/src/async_impl/client.rs b/deepseek-api/src/async_impl/client.rs index db88a3a..88c5fc7 100644 --- a/deepseek-api/src/async_impl/client.rs +++ b/deepseek-api/src/async_impl/client.rs @@ -143,15 +143,14 @@ impl DeepSeekClient { /// ```no_run /// #[tokio::main] /// async fn main() { - /// use deepseek_api::{request::{MessageRequest, UserMessageRequest}, DeepSeekClientBuilder, CompletionsRequestBuilder}; + /// use deepseek_api::{request::MessageRequest, DeepSeekClientBuilder, CompletionsRequestBuilder}; /// use deepseek_api::response::ChatResponse; /// use futures_util::StreamExt; /// /// let api_key = "your_api_key".to_string(); /// let client = DeepSeekClientBuilder::new(api_key).build().unwrap(); - /// let request_builder = CompletionsRequestBuilder::new(vec![MessageRequest::User( - /// UserMessageRequest::new("Hello, DeepSeek!") - /// )]); + /// let msgs = &[MessageRequest::user("Hello, DeepSeek!")]; + /// let request_builder = CompletionsRequestBuilder::new(msgs); /// /// let response = client.send_completion_request(request_builder).await.unwrap(); /// match response { diff --git a/deepseek-api/src/request.rs b/deepseek-api/src/request.rs index c0c71c0..1074b0f 100644 --- a/deepseek-api/src/request.rs +++ b/deepseek-api/src/request.rs @@ -147,7 +147,7 @@ impl StreamOptions { /// Represents the temperature with a value between 0 and 2. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct Temperature(pub u32); +pub struct Temperature(pub f32); impl Temperature { /// Creates a new `Temperature` instance. @@ -159,8 +159,8 @@ impl Temperature { /// # Errors /// /// Returns an error if the value is not between 0 and 2. - pub fn new(v: u32) -> Result { - if v > 2 { + pub fn new(v: f32) -> Result { + if !(0.0..=2.0).contains(&v) { return Err(anyhow!("Temperature must be between 0 and 2.".to_string())); } Ok(Temperature(v)) @@ -170,7 +170,7 @@ impl Temperature { impl Default for Temperature { /// Returns the default value for `Temperature`, which is 1. fn default() -> Self { - Temperature(1) + Temperature(1.0) } } @@ -304,6 +304,30 @@ pub enum MessageRequest { } impl MessageRequest { + /// Creates a new `MessageRequest` instance for a user message. + /// + /// # Arguments + /// + /// * `content` - The content of the user message. + /// * `name` - An optional name for the user message. + pub fn user(content: &str) -> Self { + MessageRequest::User(UserMessageRequest { + content: content.to_string(), + name: None, + }) + } + + /// Creates a new `MessageRequest` instance for a system message. + /// + /// # Arguments + /// + /// * `content` - The content of the system message. + pub fn sys(content: &str) -> Self { + MessageRequest::System(SystemMessageRequest { + content: content.to_string(), + name: None, + }) + } pub fn get_content(&self) -> &str { match self { MessageRequest::System(req) => req.content.as_str(), diff --git a/deepseek-api/src/request_builder.rs b/deepseek-api/src/request_builder.rs index 33fc0c8..06c2971 100644 --- a/deepseek-api/src/request_builder.rs +++ b/deepseek-api/src/request_builder.rs @@ -1,14 +1,13 @@ -use serde::{de::DeserializeOwned, ser::SerializeStruct, Deserialize, Serialize, Serializer}; +use serde::{de::DeserializeOwned, ser::SerializeStruct, Serialize, Serializer}; use crate::{ request::{ FrequencyPenalty, MaxToken, MessageRequest, PresencePenalty, ResponseFormat, ResponseType, Stop, StreamOptions, Temperature, ToolChoice, ToolObject, TopLogprobs, TopP, - UserMessageRequest, }, response::{ - AssistantMessage, ChatCompletion, ChatCompletionStream, ChatResponse, JSONChoiceStream, - ModelType, TextChoiceStream, + ChatCompletion, ChatCompletionStream, ChatResponse, JSONChoiceStream, ModelType, + TextChoiceStream, }, DeepSeekClient, }; @@ -37,41 +36,28 @@ pub trait RequestBuilder: Sized + Send { } /// Represents a request for completions. -#[derive(Debug, Default, Clone, Deserialize)] -pub struct CompletionsRequest { - pub messages: Vec, +#[derive(Debug, Default, Clone)] +pub struct CompletionsRequest<'a> { + pub messages: &'a [MessageRequest], pub model: ModelType, - pub prompt: String, - #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub response_format: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub stop: Option, pub stream: bool, - #[serde(skip_serializing_if = "Option::is_none")] pub stream_options: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, - #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option<&'a [ToolObject]>, pub tool_choice: Option, // ignore when model is deepseek-reasoner - #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub presence_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub frequency_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub logprobs: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub top_logprobs: Option, } -impl Serialize for CompletionsRequest { +impl Serialize for CompletionsRequest<'_> { fn serialize(&self, serializer: S) -> Result where S: Serializer, @@ -80,23 +66,47 @@ impl Serialize for CompletionsRequest { state.serialize_field("messages", &self.messages)?; state.serialize_field("model", &self.model)?; - state.serialize_field("max_tokens", &self.max_tokens)?; - state.serialize_field("response_format", &self.response_format)?; - state.serialize_field("stop", &self.stop)?; + + if let Some(max_tokens) = &self.max_tokens { + state.serialize_field("max_tokens", max_tokens)?; + } + if let Some(response_format) = &self.response_format { + state.serialize_field("response_format", response_format)?; + } + if let Some(stop) = &self.stop { + state.serialize_field("stop", stop)?; + } state.serialize_field("stream", &self.stream)?; - state.serialize_field("stream_options", &self.stream_options)?; - state.serialize_field("tools", &self.tools)?; - state.serialize_field("tool_choice", &self.tool_choice)?; - state.serialize_field("prompt", &self.prompt)?; + if let Some(stream_options) = &self.stream_options { + state.serialize_field("stream_options", stream_options)?; + } + if let Some(tools) = &self.tools { + state.serialize_field("tools", tools)?; + } + if let Some(tool_choice) = &self.tool_choice { + state.serialize_field("tool_choice", tool_choice)?; + } // Skip these fields if model is DeepSeekReasoner if self.model != ModelType::DeepSeekReasoner { - state.serialize_field("temperature", &self.temperature)?; - state.serialize_field("top_p", &self.top_p)?; - state.serialize_field("presence_penalty", &self.presence_penalty)?; - state.serialize_field("frequency_penalty", &self.frequency_penalty)?; - state.serialize_field("logprobs", &self.logprobs)?; - state.serialize_field("top_logprobs", &self.top_logprobs)?; + if let Some(temperature) = &self.temperature { + state.serialize_field("temperature", temperature)?; + } + if let Some(top_p) = &self.top_p { + state.serialize_field("top_p", top_p)?; + } + if let Some(presence_penalty) = &self.presence_penalty { + state.serialize_field("presence_penalty", presence_penalty)?; + } + if let Some(frequency_penalty) = &self.frequency_penalty { + state.serialize_field("frequency_penalty", frequency_penalty)?; + } + if let Some(logprobs) = &self.logprobs { + state.serialize_field("logprobs", logprobs)?; + } + if let Some(top_logprobs) = &self.top_logprobs { + state.serialize_field("top_logprobs", top_logprobs)?; + } } state.end() @@ -104,10 +114,10 @@ impl Serialize for CompletionsRequest { } #[derive(Debug, Default)] -pub struct CompletionsRequestBuilder { +pub struct CompletionsRequestBuilder<'a> { //todo too many colone when use this type, improve it especially for message field beta: bool, - messages: Vec, + messages: &'a [MessageRequest], model: ModelType, stream: bool, @@ -116,9 +126,8 @@ pub struct CompletionsRequestBuilder { max_tokens: Option, response_format: Option, stop: Option, - tools: Option>, + tools: Option<&'a [ToolObject]>, tool_choice: Option, - prompt: String, temperature: Option, top_p: Option, presence_penalty: Option, @@ -127,12 +136,11 @@ pub struct CompletionsRequestBuilder { top_logprobs: Option, } -impl CompletionsRequestBuilder { - pub fn new(messages: Vec) -> Self { +impl<'a> CompletionsRequestBuilder<'a> { + pub fn new(messages: &'a [MessageRequest]) -> Self { Self { messages, model: ModelType::DeepSeekChat, - prompt: String::new(), ..Default::default() } } @@ -141,25 +149,6 @@ impl CompletionsRequestBuilder { self } - //https://api-docs.deepseek.com/guides/fim_completion - pub fn append_fim_message(self, _prompt: &str, _suffix: &str) -> Self { - todo!("Not enough detail in document") - } - - // https://api-docs.deepseek.com/zh-cn/guides/chat_prefix_completion - pub fn append_prefix_message(mut self, msg: &str) -> Self { - self.messages.push(MessageRequest::Assistant( - AssistantMessage::new(msg).set_prefix(msg), - )); - self - } - - pub fn append_user_message(mut self, msg: &str) -> Self { - self.messages - .push(MessageRequest::User(UserMessageRequest::new(msg))); - self - } - pub fn max_tokens(mut self, value: u32) -> Result { self.max_tokens = Some(MaxToken::new(value)?); Ok(self) @@ -190,7 +179,7 @@ impl CompletionsRequestBuilder { self } - pub fn tools(mut self, value: Vec) -> Self { + pub fn tools(mut self, value: &'a [ToolObject]) -> Self { self.tools = Some(value); self } @@ -200,12 +189,7 @@ impl CompletionsRequestBuilder { self } - pub fn prompt(mut self, value: String) -> Self { - self.prompt = value; - self - } - - pub fn temperature(mut self, value: u32) -> Result { + pub fn temperature(mut self, value: f32) -> Result { self.temperature = Some(Temperature::new(value)?); Ok(self) } @@ -236,8 +220,8 @@ impl CompletionsRequestBuilder { } } -impl RequestBuilder for CompletionsRequestBuilder { - type Request = CompletionsRequest; +impl<'a> RequestBuilder for CompletionsRequestBuilder<'a> { + type Request = CompletionsRequest<'a>; type Response = ChatCompletion; type Item = ChatCompletionStream; @@ -249,7 +233,7 @@ impl RequestBuilder for CompletionsRequestBuilder { self.stream } - fn build(self) -> CompletionsRequest { + fn build(self) -> CompletionsRequest<'a> { CompletionsRequest { messages: self.messages, model: self.model, @@ -260,7 +244,6 @@ impl RequestBuilder for CompletionsRequestBuilder { stream_options: self.stream_options, tools: self.tools, tool_choice: self.tool_choice, - prompt: self.prompt, temperature: self.temperature, top_p: self.top_p, presence_penalty: self.presence_penalty, @@ -277,6 +260,7 @@ pub struct FMICompletionsRequest { pub model: ModelType, pub prompt: String, pub echo: bool, + pub suffix: String, #[serde(skip_serializing_if = "Option::is_none")] pub frequency_penalty: Option, @@ -291,7 +275,7 @@ pub struct FMICompletionsRequest { pub stream: bool, #[serde(skip_serializing_if = "Option::is_none")] pub stream_options: Option, - pub suffix: String, + #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -368,7 +352,7 @@ impl FMICompletionsRequestBuilder { self } - pub fn temperature(mut self, value: u32) -> Result { + pub fn temperature(mut self, value: f32) -> Result { self.temperature = Some(Temperature::new(value)?); Ok(self) } diff --git a/deepseek-api/src/sync_impl/client.rs b/deepseek-api/src/sync_impl/client.rs index 3706570..5f66eee 100644 --- a/deepseek-api/src/sync_impl/client.rs +++ b/deepseek-api/src/sync_impl/client.rs @@ -128,15 +128,14 @@ impl DeepSeekClient { /// # Example /// /// ```no_run - /// use deepseek_api::{request::{MessageRequest, UserMessageRequest}, DeepSeekClientBuilder, CompletionsRequestBuilder}; + /// use deepseek_api::{request::MessageRequest, DeepSeekClientBuilder, CompletionsRequestBuilder}; /// use deepseek_api::response::ChatResponse; /// use futures_util::StreamExt; /// /// let api_key = "your_api_key".to_string(); /// let client = DeepSeekClientBuilder::new(api_key).build().unwrap(); - /// let request_builder = CompletionsRequestBuilder::new(vec![MessageRequest::User( - /// UserMessageRequest::new("Hello, DeepSeek!") - /// )]); + /// let msgs = &[MessageRequest::user("Hello, DeepSeek!")]; + /// let request_builder = CompletionsRequestBuilder::new(msgs); /// /// let response = client.send_completion_request(request_builder).unwrap(); /// match response { diff --git a/ds-cli/src/main.rs b/ds-cli/src/main.rs index 9d629f2..fb84aa8 100644 --- a/ds-cli/src/main.rs +++ b/ds-cli/src/main.rs @@ -55,7 +55,7 @@ fn main() -> Result<()> { thread::spawn(move || loop { //request thread let msg: String = req_receiver.recv().unwrap(); - let (req_msgs, btn_state) = { + let resp = { let mut req_state = req_state.write().unwrap(); if req_state.btn_state.send_system_msg { req_state @@ -74,20 +74,19 @@ fn main() -> Result<()> { content: Some(msg.clone()), reasoning_content: None, }); - (req_state.history.clone(), req_state.btn_state.clone()) - }; - let model = if btn_state.use_reasoning_model { - ModelType::DeepSeekReasoner - } else { - ModelType::DeepSeekChat + let model = if req_state.btn_state.use_reasoning_model { + ModelType::DeepSeekReasoner + } else { + ModelType::DeepSeekChat + }; + CompletionsRequestBuilder::new(&req_state.history) + .stream(true) + .use_model(model) + .do_request(&client) + .unwrap() + .must_stream() }; - let resp = CompletionsRequestBuilder::new(req_msgs.clone()) - .stream(true) - .use_model(model) - .do_request(&client) - .unwrap() - .must_stream(); let mut content_buf = String::new(); let mut reasoning_content_buf = String::new(); diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 55b6a10..21cf4b9 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -1,5 +1,6 @@ use anyhow::Result; use clap::Parser; +use deepseek_api::request::MessageRequest; use deepseek_api::response::ModelType; use deepseek_api::{CompletionsRequestBuilder, DeepSeekClientBuilder, RequestBuilder}; use std::io::{stdin, stdout, Write}; @@ -42,9 +43,8 @@ async fn main() -> Result<()> { println!("models {:?}", models); } word => { - let resp = CompletionsRequestBuilder::new(vec![]) + let resp = CompletionsRequestBuilder::new(&[MessageRequest::user(word)]) .use_model(ModelType::DeepSeekChat) - .append_user_message(word) .do_request(&client) .await? .must_response(); diff --git a/examples/deep-think/src/main.rs b/examples/deep-think/src/main.rs index 70dfee6..8f4be76 100644 --- a/examples/deep-think/src/main.rs +++ b/examples/deep-think/src/main.rs @@ -1,8 +1,7 @@ use anyhow::Result; use clap::Parser; use deepseek_api::{ - CompletionsRequestBuilder, DeepSeekClientBuilder, RequestBuilder, - request::{MessageRequest, UserMessageRequest}, + CompletionsRequestBuilder, DeepSeekClientBuilder, RequestBuilder, request::MessageRequest, response::ModelType, }; use tokio_stream::StreamExt; @@ -19,14 +18,13 @@ async fn main() -> Result<()> { let args = Args::parse(); let client = DeepSeekClientBuilder::new(args.api_key.clone()).build()?; - let mut stream = CompletionsRequestBuilder::new(vec![MessageRequest::User( - UserMessageRequest::new("how to get to beijing"), - )]) - .use_model(ModelType::DeepSeekReasoner) - .stream(true) - .do_request(&client) - .await? - .must_stream(); + let mut stream = + CompletionsRequestBuilder::new(&[MessageRequest::user("how to get to beijing")]) + .use_model(ModelType::DeepSeekReasoner) + .stream(true) + .do_request(&client) + .await? + .must_stream(); while let Some(item) = stream.next().await { println!("resp: {:?}", item); } diff --git a/examples/function-call/src/main.rs b/examples/function-call/src/main.rs index cf661cb..8d3364a 100644 --- a/examples/function-call/src/main.rs +++ b/examples/function-call/src/main.rs @@ -1,9 +1,7 @@ use anyhow::Result; use clap::Parser; use deepseek_api::request::MessageRequest; -use deepseek_api::request::{ - Function, ToolMessageRequest, ToolObject, ToolType, UserMessageRequest, -}; +use deepseek_api::request::{Function, ToolMessageRequest, ToolObject, ToolType}; use deepseek_api::response::FinishReason; use deepseek_api::{CompletionsRequestBuilder, DeepSeekClientBuilder, RequestBuilder}; use schemars::schema::SchemaObject; @@ -53,11 +51,10 @@ async fn main() -> Result<()> { }, }; - let mut messages = vec![MessageRequest::User(UserMessageRequest::new( - "How's the weather in Hangzhou?", - ))]; - let resp = CompletionsRequestBuilder::new(messages.clone()) - .tools(vec![tool_object.clone()]) + let tool_objects: Vec = vec![tool_object]; + let mut messages = vec![MessageRequest::user("How's the weather in Hangzhou?")]; + let resp = CompletionsRequestBuilder::new(&messages) + .tools(&tool_objects) .do_request(&client) .await? .must_response(); @@ -75,8 +72,8 @@ async fn main() -> Result<()> { } messages.push(MessageRequest::Tool(ToolMessageRequest::new("24℃", &id))); - let resp = CompletionsRequestBuilder::new(messages.clone()) - .tools(vec![tool_object.clone()]) + let resp = CompletionsRequestBuilder::new(&messages) + .tools(&tool_objects) .do_request(&client) .await? .must_response(); diff --git a/examples/sync-basic/src/main.rs b/examples/sync-basic/src/main.rs index bc67e78..0099dcd 100644 --- a/examples/sync-basic/src/main.rs +++ b/examples/sync-basic/src/main.rs @@ -19,9 +19,8 @@ fn main() -> Result<()> { .build()?; let mut history = vec![]; - let resp = CompletionsRequestBuilder::new(vec![]) + let resp = CompletionsRequestBuilder::new(&[MessageRequest::user("hello world")]) .use_model(ModelType::DeepSeekReasoner) - .append_user_message("hello world") .do_request(&client)? .must_response();