Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ func (c *Client) CreateChatCompletion(
return nil, fmt.Errorf("request cannot be nil")
}

ctx, tcancel, err := getTimeoutContext(ctx, c.Timeout)
if err != nil {
return nil, err
}
defer tcancel()

req, err := utils.NewRequestBuilder(c.AuthToken).
SetBaseURL(c.BaseURL).
SetPath(c.Path).
Expand Down Expand Up @@ -51,6 +57,15 @@ func (c *Client) CreateChatCompletionStream(
ctx context.Context,
request *StreamChatCompletionRequest,
) (ChatCompletionStream, error) {
if request == nil {
return nil, fmt.Errorf("request cannot be nil")
}

ctx, tcancel, err := getTimeoutContext(ctx, c.Timeout)
if err != nil {
return nil, err
}
defer tcancel()

request.Stream = true
req, err := utils.NewRequestBuilder(c.AuthToken).
Expand Down
10 changes: 10 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type Client struct {
BaseURL string // The base URL for the API
Timeout time.Duration // The timeout for the current Client
Path string // The path for the API request. Defaults to "chat/completions"

HTTPClient HTTPDoer // The HTTP client to send the request and get the response
}

// NewClient creates a new client with an authentication token and an optional custom baseURL.
Expand Down Expand Up @@ -115,3 +117,11 @@ func WithPath(path string) Option {
return nil
}
}

// WithHTTPClient sets the http client for the API client.
func WithHTTPClient(httpclient HTTPDoer) Option {
return func(c *Client) error {
c.HTTPClient = httpclient
return nil
}
}
82 changes: 40 additions & 42 deletions requestHandler.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
package deepseek

import (
"errors"
"context"
"fmt"
"net/http"
"net/url"
"os"
"time"

"github.com/joho/godotenv"
)

// HandleTimeout gets the timeout duration from the DEEPSEEK_TIMEOUT environment variable.
//
// (xgfone): Do we need to export the function?
func HandleTimeout() (time.Duration, error) {
return handleTimeout()
}

func handleTimeout() (time.Duration, error) {
if err := godotenv.Load(); err != nil {
_ = err
}
Expand All @@ -28,59 +33,52 @@ func HandleTimeout() (time.Duration, error) {
return duration, nil
}

// checkTimeoutError checks if the error is a timeout error and returns a custom error message.
func checkTimeoutError(err error, timeout time.Duration) error {
var urlErr *url.Error
if errors.As(err, &urlErr) && urlErr.Timeout() {
return fmt.Errorf(
"request timed out after %s. You can increase the timeout by setting the DEEPSEEK_TIMEOUT environment variable. Original error: %w",
timeout,
err,
)
}
return nil
}

// HandleSendChatCompletionRequest sends a request to the DeepSeek API and returns the response.
func HandleSendChatCompletionRequest(c Client, req *http.Request) (*http.Response, error) {
// Check if c.Timeout is already set or not
timeout := c.Timeout
if timeout == 0 {
func getTimeoutContext(ctx context.Context, timeout time.Duration) (
context.Context,
context.CancelFunc,
error,
) {
if timeout <= 0 {
// Try to get timeout from environment variable
var err error
timeout, err = HandleTimeout()
timeout, err = handleTimeout()
if err != nil {
return nil, fmt.Errorf("error getting timeout: %w", err)
return nil, nil, fmt.Errorf("error getting timeout from environment: %w", err)
}
}
client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
if timeoutErr := checkTimeoutError(err, timeout); timeoutErr != nil {
return nil, timeoutErr
}
return nil, fmt.Errorf("error sending request: %w", err)

var cancel context.CancelFunc
if timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, timeout)
} else {
cancel = func() {}
}

return resp, nil
return ctx, cancel, nil
}

// HandleSendChatCompletionRequest sends a request to the DeepSeek API and returns the response.
//
// (xgfone): Do we need to export this function?
func HandleSendChatCompletionRequest(c Client, req *http.Request) (*http.Response, error) {
return c.handleRequest(req)
}

// HandleNormalRequest sends a request to the DeepSeek API and returns the response.
//
// (xgfone): Do we need to export this function?
func HandleNormalRequest(c Client, req *http.Request) (*http.Response, error) {
// Check if c.Timeout is already set or not
timeout := c.Timeout
if timeout == 0 {
var err error
timeout, err = HandleTimeout()
if err != nil {
return nil, fmt.Errorf("error getting timeout: %w", err)
}
return c.handleRequest(req)
}

func (c *Client) handleRequest(req *http.Request) (*http.Response, error) {
client := c.HTTPClient
if client == nil {
client = http.DefaultClient
}
client := &http.Client{Timeout: timeout}

resp, err := client.Do(req)
if err != nil {
if timeoutErr := checkTimeoutError(err, timeout); timeoutErr != nil {
return nil, timeoutErr
}
return nil, fmt.Errorf("error sending request: %w", err)
}

Expand Down