diff --git a/client.go b/client.go index 8c9a8a4..aece0fc 100644 --- a/client.go +++ b/client.go @@ -17,6 +17,12 @@ func (c *Client) CreateChatCompletion( return nil, fmt.Errorf("request cannot be nil") } + ctx, tcancel, err := getTimeoutContext(ctx, c.Timeout) + if err != nil { + return nil, err + } + defer tcancel() + req, err := utils.NewRequestBuilder(c.AuthToken). SetBaseURL(c.BaseURL). SetPath(c.Path). @@ -51,6 +57,15 @@ func (c *Client) CreateChatCompletionStream( ctx context.Context, request *StreamChatCompletionRequest, ) (ChatCompletionStream, error) { + if request == nil { + return nil, fmt.Errorf("request cannot be nil") + } + + ctx, tcancel, err := getTimeoutContext(ctx, c.Timeout) + if err != nil { + return nil, err + } + defer tcancel() request.Stream = true req, err := utils.NewRequestBuilder(c.AuthToken). diff --git a/config.go b/config.go index 0100f0d..f0a452e 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,8 @@ type Client struct { BaseURL string // The base URL for the API Timeout time.Duration // The timeout for the current Client Path string // The path for the API request. Defaults to "chat/completions" + + HTTPClient HTTPDoer // The HTTP client to send the request and get the response } // NewClient creates a new client with an authentication token and an optional custom baseURL. @@ -115,3 +117,11 @@ func WithPath(path string) Option { return nil } } + +// WithHTTPClient sets the http client for the API client. +func WithHTTPClient(httpclient HTTPDoer) Option { + return func(c *Client) error { + c.HTTPClient = httpclient + return nil + } +} diff --git a/requestHandler.go b/requestHandler.go index 803fc9c..7b69923 100644 --- a/requestHandler.go +++ b/requestHandler.go @@ -1,10 +1,9 @@ package deepseek import ( - "errors" + "context" "fmt" "net/http" - "net/url" "os" "time" @@ -12,7 +11,13 @@ import ( ) // HandleTimeout gets the timeout duration from the DEEPSEEK_TIMEOUT environment variable. +// +// (xgfone): Do we need to export the function? func HandleTimeout() (time.Duration, error) { + return handleTimeout() +} + +func handleTimeout() (time.Duration, error) { if err := godotenv.Load(); err != nil { _ = err } @@ -28,59 +33,52 @@ func HandleTimeout() (time.Duration, error) { return duration, nil } -// checkTimeoutError checks if the error is a timeout error and returns a custom error message. -func checkTimeoutError(err error, timeout time.Duration) error { - var urlErr *url.Error - if errors.As(err, &urlErr) && urlErr.Timeout() { - return fmt.Errorf( - "request timed out after %s. You can increase the timeout by setting the DEEPSEEK_TIMEOUT environment variable. Original error: %w", - timeout, - err, - ) - } - return nil -} - -// HandleSendChatCompletionRequest sends a request to the DeepSeek API and returns the response. -func HandleSendChatCompletionRequest(c Client, req *http.Request) (*http.Response, error) { - // Check if c.Timeout is already set or not - timeout := c.Timeout - if timeout == 0 { +func getTimeoutContext(ctx context.Context, timeout time.Duration) ( + context.Context, + context.CancelFunc, + error, +) { + if timeout <= 0 { + // Try to get timeout from environment variable var err error - timeout, err = HandleTimeout() + timeout, err = handleTimeout() if err != nil { - return nil, fmt.Errorf("error getting timeout: %w", err) + return nil, nil, fmt.Errorf("error getting timeout from environment: %w", err) } } - client := &http.Client{Timeout: timeout} - resp, err := client.Do(req) - if err != nil { - if timeoutErr := checkTimeoutError(err, timeout); timeoutErr != nil { - return nil, timeoutErr - } - return nil, fmt.Errorf("error sending request: %w", err) + + var cancel context.CancelFunc + if timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, timeout) + } else { + cancel = func() {} } - return resp, nil + return ctx, cancel, nil +} + +// HandleSendChatCompletionRequest sends a request to the DeepSeek API and returns the response. +// +// (xgfone): Do we need to export this function? +func HandleSendChatCompletionRequest(c Client, req *http.Request) (*http.Response, error) { + return c.handleRequest(req) } // HandleNormalRequest sends a request to the DeepSeek API and returns the response. +// +// (xgfone): Do we need to export this function? func HandleNormalRequest(c Client, req *http.Request) (*http.Response, error) { - // Check if c.Timeout is already set or not - timeout := c.Timeout - if timeout == 0 { - var err error - timeout, err = HandleTimeout() - if err != nil { - return nil, fmt.Errorf("error getting timeout: %w", err) - } + return c.handleRequest(req) +} + +func (c *Client) handleRequest(req *http.Request) (*http.Response, error) { + client := c.HTTPClient + if client == nil { + client = http.DefaultClient } - client := &http.Client{Timeout: timeout} + resp, err := client.Do(req) if err != nil { - if timeoutErr := checkTimeoutError(err, timeout); timeoutErr != nil { - return nil, timeoutErr - } return nil, fmt.Errorf("error sending request: %w", err) }