diff --git a/circuit/circuits.go b/circuit/circuits.go deleted file mode 100644 index 3c888d8..0000000 --- a/circuit/circuits.go +++ /dev/null @@ -1,129 +0,0 @@ -package http - -import ( - "fmt" - "time" - - "github.com/afex/hystrix-go/hystrix" -) - -type ( - CircuitFunc func() error - CircuitErrorFilter func(error) (bool, error) -) - -type CircuitConfig struct { - Name string - Timeout int - MaxConcurrentRequests int - ErrorPercentThreshold int - RequestVolumeThreshold int - SleepWindow int - Commands []string -} - -type Circuit struct { - config CircuitConfig -} - -func NewCircuit(c CircuitConfig) *Circuit { - hystrixConfig := hystrix.CommandConfig{ - Timeout: c.Timeout, - MaxConcurrentRequests: c.MaxConcurrentRequests, - ErrorPercentThreshold: c.ErrorPercentThreshold, - RequestVolumeThreshold: c.RequestVolumeThreshold, - SleepWindow: c.SleepWindow, - } - - for _, command := range c.Commands { - hystrix.ConfigureCommand(fmt.Sprintf("%s:%s", c.Name, command), hystrixConfig) - } - - return &Circuit{ - config: c, - } -} - -func (c *Circuit) Do(command string, fu CircuitFunc, fallback func(error) error, fi ...CircuitErrorFilter) error { - var e error - var ok bool - - function := func() error { - err := fu() - - for _, filter := range fi { - if ok, e = filter(err); ok { - return err - } - } - - if len(fi) > 0 { - return nil - } - - return err - } - - hystrixErr := hystrix.Do(fmt.Sprintf("%s:%s", c.config.Name, command), function, fallback) - - if hystrixErr != nil { - return hystrixErr - } - - if e != nil { - return e - } - - return nil -} - -func (c *Circuit) DoR( - command string, - fu CircuitFunc, - fallback func(error) error, - retry int, - delay time.Duration, - fi ...CircuitErrorFilter, -) error { - var e error - var ok bool - - filter := func() error { - err := fu() - - for _, filter := range fi { - if ok, e = filter(err); ok { - return err - } - } - - if len(fi) > 0 { - return nil - } - - return err - } - - function := func() error { - err := filter() - - for c := 0; c < retry && err != nil; c++ { - time.Sleep(delay) - err = filter() - } - - return err - } - - hystrixErr := hystrix.Do(fmt.Sprintf("%s:%s", c.config.Name, command), function, fallback) - - if hystrixErr != nil { - return hystrixErr - } - - if e != nil { - return e - } - - return nil -} diff --git a/error/error.go b/error/error.go new file mode 100644 index 0000000..a64a5e9 --- /dev/null +++ b/error/error.go @@ -0,0 +1,58 @@ +package error + +import ( + "fmt" + "strings" +) + +type ChakiError struct { + ClientName string + StatusCode int + RawBody []byte + ParsedBody interface{} +} + +func (e *ChakiError) Error() string { + msg := fmt.Sprintf("Error on client %s (Status %d)", e.ClientName, e.StatusCode) + if details := e.extractErrorDetails(); details != "" { + msg += ": " + details + } + + return msg +} + +func (e *ChakiError) Status() int { + return e.StatusCode +} + +type RandomError interface { + Status() int +} + +func (e *ChakiError) extractErrorDetails() string { + var details []string + + var extract func(interface{}) + extract = func(v interface{}) { + switch value := v.(type) { + case string: + details = append(details, strings.TrimSpace(value)) + case map[string]interface{}: + for _, v := range value { + extract(v) + } + case []interface{}: + for _, v := range value { + extract(v) + } + } + } + + extract(e.ParsedBody) + + if len(details) == 0 && len(e.RawBody) > 0 { + return strings.TrimSpace(string(e.RawBody)) + } + + return strings.Join(details, "; ") +} diff --git a/example/client-server-with-otel/client/client.go b/example/client-server-with-otel/client/client.go deleted file mode 100644 index dcc2fb4..0000000 --- a/example/client-server-with-otel/client/client.go +++ /dev/null @@ -1,76 +0,0 @@ -package main - -import ( - "context" - "fmt" - - "github.com/Trendyol/chaki/modules/client" - "github.com/Trendyol/chaki/modules/server/response" -) - -type exampleClient struct { - *client.Base -} - -func newClient(f *client.Factory) *exampleClient { - return &exampleClient{ - Base: f.Get("example-client"), - } -} - -func (cl *exampleClient) SendHello(ctx context.Context) (string, error) { - resp := &response.Response[string]{} - if _, err := cl.Request(ctx).SetResult(resp).Get("/hello"); err != nil { - return "", err - } - - return resp.Data, nil -} - -func (cl *exampleClient) sendGreetWithQuery(ctx context.Context, req GreetWithQueryRequest) (string, error) { - resp := &response.Response[string]{} - - params := map[string]string{ - "text": req.Text, - "repeatTimes": fmt.Sprintf("%d", req.RepeatTimes), - } - - if _, err := cl.Request(ctx). - SetResult(resp). - SetQueryParams(params). - Get("/hello/query"); err != nil { - return "", err - } - return resp.Data, nil -} - -func (cl *exampleClient) sendGreetWithParam(ctx context.Context, req GreetWithParamRequest) (string, error) { - resp := &response.Response[string]{} - - url := fmt.Sprintf("/hello/param/%s", req.Text) - params := map[string]string{ - "repeatTimes": fmt.Sprintf("%d", req.RepeatTimes), - } - - if _, err := cl.Request(ctx). - SetResult(resp). - SetQueryParams(params). - Get(url); err != nil { - return "", err - } - - return resp.Data, nil -} - -func (cl *exampleClient) sendGreetWithBody(ctx context.Context, req GreetWithBodyRequest) (string, error) { - resp := &response.Response[string]{} - - if _, err := cl.Request(ctx). - SetResult(resp). - SetBody(req). - Post("/hello/body"); err != nil { - return "", err - } - - return resp.Data, nil -} diff --git a/example/client-server-with-otel/client/main.go b/example/client-server-with-otel/client/main.go deleted file mode 100644 index e4af20b..0000000 --- a/example/client-server-with-otel/client/main.go +++ /dev/null @@ -1,80 +0,0 @@ -package main - -import ( - "context" - - "github.com/Trendyol/chaki" - "github.com/Trendyol/chaki/logger" - "github.com/Trendyol/chaki/modules/client" - "github.com/Trendyol/chaki/modules/otel" - otelclient "github.com/Trendyol/chaki/modules/otel/client" - otelserver "github.com/Trendyol/chaki/modules/otel/server" - - "github.com/Trendyol/chaki/modules/server" - "github.com/Trendyol/chaki/modules/server/controller" - "github.com/Trendyol/chaki/modules/server/route" -) - -func main() { - app := chaki.New() - - app.Use( - client.Module(), - server.Module(), - otel.Module( - otelclient.WithClient(), - otelserver.WithServer(), - ), - ) - - app.Provide( - newClient, - NewController, - ) - - if err := app.Start(); err != nil { - logger.Fatal(err) - } -} - -type serverController struct { - *controller.Base - cl *exampleClient -} - -func NewController(cl *exampleClient) controller.Controller { - return &serverController{ - Base: controller.New("server-controller").SetPrefix("/hello"), - cl: cl, - } -} - -func (ct *serverController) hello(ctx context.Context, _ any) (string, error) { - logger.From(ctx).Info("hello") - resp, err := ct.cl.SendHello(ctx) - if err != nil { - return "", err - } - return resp, nil -} - -func (ct *serverController) greetWithBody(ctx context.Context, req GreetWithBodyRequest) (string, error) { - return ct.cl.sendGreetWithBody(ctx, req) -} - -func (ct *serverController) greetWithQuery(ctx context.Context, req GreetWithQueryRequest) (string, error) { - return ct.cl.sendGreetWithQuery(ctx, req) -} - -func (ct *serverController) greetWithParam(ctx context.Context, req GreetWithParamRequest) (string, error) { - return ct.cl.sendGreetWithParam(ctx, req) -} - -func (ct *serverController) Routes() []route.Route { - return []route.Route{ - route.Get("", ct.hello), - route.Get("/query", ct.greetWithQuery).Name("Greet with query"), - route.Get("/param/:text", ct.greetWithParam).Name("Greet with param"), - route.Post("/body", ct.greetWithBody).Name("Greet with body"), - } -} diff --git a/example/client-server-with-otel/client/request.go b/example/client-server-with-otel/client/request.go deleted file mode 100644 index 8384481..0000000 --- a/example/client-server-with-otel/client/request.go +++ /dev/null @@ -1,16 +0,0 @@ -package main - -type GreetWithBodyRequest struct { - Text string `json:"text" validate:"required,min=3,max=100"` - RepeatTimes int `json:"repeatTimes" validate:"required,gte=1,lte=5"` -} - -type GreetWithQueryRequest struct { - Text string `query:"text" validate:"required,min=3,max=100"` - RepeatTimes int `query:"repeatTimes" validate:"required,gte=1,lte=5"` -} - -type GreetWithParamRequest struct { - Text string `param:"text" validate:"required,min=3,max=100"` - RepeatTimes int `query:"repeatTimes" validate:"required,gte=1,lte=5"` -} diff --git a/example/client-server-with-otel/client/resources/configs/application.yaml b/example/client-server-with-otel/client/resources/configs/application.yaml deleted file mode 100644 index 45d669f..0000000 --- a/example/client-server-with-otel/client/resources/configs/application.yaml +++ /dev/null @@ -1,6 +0,0 @@ -server: - addr: ":8081" - -client: - example-client: - baseUrl: "http://localhost:8082" diff --git a/example/client-server-with-otel/server/main.go b/example/client-server-with-otel/server/main.go deleted file mode 100644 index ae582fd..0000000 --- a/example/client-server-with-otel/server/main.go +++ /dev/null @@ -1,75 +0,0 @@ -package main - -import ( - "context" - - "github.com/Trendyol/chaki" - "github.com/Trendyol/chaki/logger" - "github.com/Trendyol/chaki/modules/otel" - otelclient "github.com/Trendyol/chaki/modules/otel/client" - otelserver "github.com/Trendyol/chaki/modules/otel/server" - "github.com/Trendyol/chaki/modules/server" - "github.com/Trendyol/chaki/modules/server/controller" - "github.com/Trendyol/chaki/modules/server/route" -) - -func main() { - app := chaki.New() - - app.Use( - server.Module(), - otel.Module( - otelclient.WithClient(), - otelserver.WithServer(), - ), - ) - - app.Provide( - NewController, - ) - - if err := app.Start(); err != nil { - logger.Fatal(err) - } -} - -type serverController struct { - *controller.Base -} - -func NewController() controller.Controller { - return &serverController{ - Base: controller.New("server-controller").SetPrefix("/hello"), - } -} - -func (ct *serverController) hello(ctx context.Context, _ any) (string, error) { - logger.From(ctx).Info("hello from server") - return "Hi From Server", nil -} - -func (ct *serverController) Routes() []route.Route { - return []route.Route{ - route.Get("/", ct.hello), - route.Get("/greet", ct.greetHandler).Name("Greet Route"), - route.Get("/query", ct.greetWithQuery).Name("Greet with query"), - route.Get("/param/:text", ct.greetWithParam).Name("Greet with param"), - route.Post("/body", ct.greetWithBody).Name("Greet with body"), - } -} - -func (ct *serverController) greetHandler(_ context.Context, _ struct{}) (string, error) { - return "Greetings!", nil -} - -func (ct *serverController) greetWithBody(_ context.Context, req GreetWithBodyRequest) (string, error) { - return req.Text, nil -} - -func (ct *serverController) greetWithQuery(_ context.Context, req GreetWithQueryRequest) (string, error) { - return req.Text, nil -} - -func (ct *serverController) greetWithParam(_ context.Context, req GreetWithParamRequest) (string, error) { - return req.Text, nil -} diff --git a/example/client-server-with-otel/server/request.go b/example/client-server-with-otel/server/request.go deleted file mode 100644 index 8384481..0000000 --- a/example/client-server-with-otel/server/request.go +++ /dev/null @@ -1,16 +0,0 @@ -package main - -type GreetWithBodyRequest struct { - Text string `json:"text" validate:"required,min=3,max=100"` - RepeatTimes int `json:"repeatTimes" validate:"required,gte=1,lte=5"` -} - -type GreetWithQueryRequest struct { - Text string `query:"text" validate:"required,min=3,max=100"` - RepeatTimes int `query:"repeatTimes" validate:"required,gte=1,lte=5"` -} - -type GreetWithParamRequest struct { - Text string `param:"text" validate:"required,min=3,max=100"` - RepeatTimes int `query:"repeatTimes" validate:"required,gte=1,lte=5"` -} diff --git a/example/client-server-with-otel/server/resources/configs/secrets.yaml b/example/client-server-with-otel/server/resources/configs/secrets.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/example/client-server/client/client.go b/example/client-server/client/client.go index 578ee58..1a3a2a7 100644 --- a/example/client-server/client/client.go +++ b/example/client-server/client/client.go @@ -2,77 +2,115 @@ package main import ( "context" - "fmt" - "github.com/Trendyol/chaki/modules/client" "github.com/Trendyol/chaki/modules/server/response" + "github.com/go-resty/resty/v2" +) + +const ( + errorEndpoint = "{category}/error" + notFoundEndpoint = "{category}/not-found" + successfulEndpoint = "{category}" ) -type exampleClient struct { +type CustomClient struct { *client.Base } -func newClient(f *client.Factory) *exampleClient { - return &exampleClient{ - Base: f.Get("example-client", - client.WithErrDecoder(customErrorDecoder), - client.WithDriverWrappers(HeaderWrapper())), +type UltimateRequestBody struct { + Message string `json:"message"` +} + +func NewCustomClient(f *client.Factory) *CustomClient { + + return &CustomClient{ + Base: f.Get("custom-client", client.WithErrDecoder(customErrorDecoder)), } } -func (cl *exampleClient) SendHello(ctx context.Context) (string, error) { - resp := &response.Response[string]{} - if _, err := cl.Request(ctx).SetResult(resp).Get("/hello"); err != nil { - return "", err +func customErrorDecoder(_ context.Context, res *resty.Response) error { + if res.IsSuccess() { + return nil } - return resp.Data, nil + if res.StatusCode() == 404 { + return client.GenericClientError{ParsedBody: "not found from custom err decoder", StatusCode: res.StatusCode()} + } + + return client.GenericClientError{ParsedBody: "generic error se the code :)", StatusCode: res.StatusCode()} } -func (cl *exampleClient) sendGreetWithQuery(ctx context.Context, req GreetWithQueryRequest) (string, error) { +func (c *CustomClient) SuccessfulEndpoint(ctx context.Context, req *UltimateRequest) (*response.Response[string], error) { resp := &response.Response[string]{} - params := map[string]string{ - "text": req.Text, - "repeatTimes": fmt.Sprintf("%d", req.RepeatTimes), - } - - if _, err := cl.Request(ctx). + if _, err := c.RequestWithCommand(ctx, "commandpostsuccess"). + SetPathParam("category", req.Category). + SetBody(UltimateRequestBody{Message: req.Message}). + SetQueryParam("lang", req.Language). SetResult(resp). - SetQueryParams(params). - Get("/hello/query"); err != nil { - return "", err + Post(successfulEndpoint); err != nil { + return nil, err } - return resp.Data, nil + + return resp, nil } -func (cl *exampleClient) sendGreetWithParam(ctx context.Context, req GreetWithParamRequest) (string, error) { +func (c *CustomClient) GetNotFoundErr(ctx context.Context, req *UltimateRequest) (*response.Response[string], error) { resp := &response.Response[string]{} - url := fmt.Sprintf("/hello/param/%s", req.Text) - params := map[string]string{ - "repeatTimes": fmt.Sprintf("%d", req.RepeatTimes), + if _, err := c.RequestWithCommand(ctx, "commandposterror"). + SetPathParam("category", req.Category). + SetBody(UltimateRequestBody{Message: req.Message}). + SetQueryParam("lang", req.Language). + SetResult(resp). + Post(notFoundEndpoint); err != nil { + return nil, err } - if _, err := cl.Request(ctx). + return resp, nil +} + +func (c *CustomClient) GetError(ctx context.Context, req *UltimateRequest) (*response.Response[string], error) { + resp := &response.Response[string]{} + + if _, err := c.RequestWithCommand(ctx, "commandposterror"). + SetPathParam("category", req.Category). + SetBody(UltimateRequestBody{Message: req.Message}). + SetQueryParam("lang", req.Language). SetResult(resp). - SetQueryParams(params). - Get(url); err != nil { - return "", err + Post(errorEndpoint); err != nil { + return nil, err } - return resp.Data, nil + return resp, nil } -func (cl *exampleClient) sendGreetWithBody(ctx context.Context, req GreetWithBodyRequest) (string, error) { +func (c *CustomClient) GetErrorWithFallback(ctx context.Context, req *UltimateRequest) (*response.Response[string], error) { resp := &response.Response[string]{} - if _, err := cl.Request(ctx). + ctx = client.SetFallbackFunc(ctx, func(ctx context.Context, err error) (any, error) { + + // Any fallback mechanism can apply here. The tricky part here is, + // If you use .SetResult(res) method from resty.Request, + // You should return the same here. If not, you can handle + // The fallback response as you wish by using httpRes.Result() + return &response.Response[string]{ + Data: "this response is from fallback", + }, nil + }) + + httpRes, err := c.RequestWithCommand(ctx, "commandpostfallback"). + SetPathParam("category", req.Category). + SetBody(UltimateRequestBody{Message: req.Message}). + SetQueryParam("lang", "en"). SetResult(resp). - SetBody(req). - Post("/hello/body"); err != nil { - return "", err + Post(errorEndpoint) + _ = httpRes + + if err != nil { + return nil, err } - return resp.Data, nil + return resp, nil + } diff --git a/example/client-server/client/client_wrapper.go b/example/client-server/client/client_wrapper.go deleted file mode 100644 index a40873f..0000000 --- a/example/client-server/client/client_wrapper.go +++ /dev/null @@ -1,30 +0,0 @@ -package main - -import ( - "github.com/Trendyol/chaki/modules/client" - "github.com/go-resty/resty/v2" -) - -type user struct { - publicUsername string - publicTag string -} - -func HeaderWrapper() client.DriverWrapper { - return func(restyClient *resty.Client) *resty.Client { - return restyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error { - ctx := r.Context() - - h := map[string]string{} - - if v := ctx.Value("user"); v != nil { - user := v.(user) - h["publicUsername"] = user.publicUsername - h["publicTag"] = user.publicTag - } - - r.SetHeaders(h) - return nil - }) - } -} diff --git a/example/client-server/client/config.yaml b/example/client-server/client/config.yaml new file mode 100644 index 0000000..9464e7c --- /dev/null +++ b/example/client-server/client/config.yaml @@ -0,0 +1,26 @@ +server: + addr: ":8081" + +client: + retryPresets: + - name: "Custom Preset - 1" + count: 3 + interval: 100ms + maxDelay: 2s + delayType: constant + custom-client: + baseUrl: "http://localhost:8082" + circuit: + preset: custom + enabled: true + timeout: 3000 + maxConcurrentRequests: 1 + errorPercentThreshold: 1 + requestVolumeThreshold: 1 + sleepWindow: 5000 + statusCodeConfig: + treatAllErrorCodesAsFailure: true + ignoreStatusCodes: [404] + retry: + enabled: true + preset: "Custom Preset - 1" diff --git a/example/client-server/client/error_decoder.go b/example/client-server/client/error_decoder.go deleted file mode 100644 index bd19de7..0000000 --- a/example/client-server/client/error_decoder.go +++ /dev/null @@ -1,15 +0,0 @@ -package main - -import ( - "context" - "fmt" - - "github.com/go-resty/resty/v2" -) - -func customErrorDecoder(_ context.Context, res *resty.Response) error { - if res.StatusCode() == 404 { - return fmt.Errorf("not found") - } - return nil -} diff --git a/example/client-server/client/main.go b/example/client-server/client/main.go index c879e02..4c02687 100644 --- a/example/client-server/client/main.go +++ b/example/client-server/client/main.go @@ -4,72 +4,47 @@ import ( "context" "github.com/Trendyol/chaki" - "github.com/Trendyol/chaki/logger" "github.com/Trendyol/chaki/modules/client" + "github.com/Trendyol/chaki/modules/otel" + otelclient "github.com/Trendyol/chaki/modules/otel/client" + otelserver "github.com/Trendyol/chaki/modules/otel/server" "github.com/Trendyol/chaki/modules/server" - "github.com/Trendyol/chaki/modules/server/controller" - "github.com/Trendyol/chaki/modules/server/route" "github.com/Trendyol/chaki/modules/swagger" ) func main() { + app := chaki.New() + app.WithOption( + chaki.WithConfigPath("config.yaml"), + ) + app.Use( - client.Module(), server.Module(), + client.Module(), + // To add otel module, simply add the following line + // This requires otel init function and submodules. + otel.Module( + otel.WithInitFunc(customOtelInitFunc), + otelserver.WithServer(), + otelclient.WithClient(), + ), swagger.Module(), ) app.Provide( - newClient, - NewController, + NewCustomController, + NewCustomClient, ) - if err := app.Start(); err != nil { - logger.Fatal(err) - } -} - -type serverController struct { - *controller.Base - cl *exampleClient -} - -func NewController(cl *exampleClient) controller.Controller { - return &serverController{ - Base: controller.New("server-controller").SetPrefix("/hello"), - cl: cl, - } -} - -func (ct *serverController) hello(ctx context.Context, _ any) (string, error) { - logger.From(ctx).Info("hello") - resp, err := ct.cl.SendHello(ctx) - if err != nil { - return "", err - } - return resp, nil -} - -func (ct *serverController) greetWithBody(ctx context.Context, req GreetWithBodyRequest) (string, error) { - return ct.cl.sendGreetWithBody(ctx, req) -} - -func (ct *serverController) greetWithQuery(ctx context.Context, req GreetWithQueryRequest) (string, error) { - return ct.cl.sendGreetWithQuery(ctx, req) -} - -func (ct *serverController) greetWithParam(ctx context.Context, req GreetWithParamRequest) (string, error) { - return ct.cl.sendGreetWithParam(ctx, req) + _ = app.Start() } -func (ct *serverController) Routes() []route.Route { - return []route.Route{ - route.Get("", ct.hello), - route.Get("/query", ct.greetWithQuery).Name("Greet with query"), - route.Get("/param/:text", ct.greetWithParam).Name("Greet with param"), - route.Post("/body", ct.greetWithBody).Name("Greet with body"), +// You should be setting your propagations, exporters, and other configurations here. +func customOtelInitFunc() otel.CloseFunc { + return func(ctx context.Context) error { + return nil } } diff --git a/example/client-server/client/model.go b/example/client-server/client/model.go new file mode 100644 index 0000000..125e528 --- /dev/null +++ b/example/client-server/client/model.go @@ -0,0 +1,7 @@ +package main + +type UltimateRequest struct { + Category string `param:"category" validate:"required"` + Language string `query:"lang" validate:"required,len=2"` + Message string `json:"message" validate:"required,min=3,max=100"` +} diff --git a/example/client-server/client/request.go b/example/client-server/client/request.go deleted file mode 100644 index 8384481..0000000 --- a/example/client-server/client/request.go +++ /dev/null @@ -1,16 +0,0 @@ -package main - -type GreetWithBodyRequest struct { - Text string `json:"text" validate:"required,min=3,max=100"` - RepeatTimes int `json:"repeatTimes" validate:"required,gte=1,lte=5"` -} - -type GreetWithQueryRequest struct { - Text string `query:"text" validate:"required,min=3,max=100"` - RepeatTimes int `query:"repeatTimes" validate:"required,gte=1,lte=5"` -} - -type GreetWithParamRequest struct { - Text string `param:"text" validate:"required,min=3,max=100"` - RepeatTimes int `query:"repeatTimes" validate:"required,gte=1,lte=5"` -} diff --git a/example/client-server/client/resources/configs/application.yaml b/example/client-server/client/resources/configs/application.yaml deleted file mode 100644 index 45d669f..0000000 --- a/example/client-server/client/resources/configs/application.yaml +++ /dev/null @@ -1,6 +0,0 @@ -server: - addr: ":8081" - -client: - example-client: - baseUrl: "http://localhost:8082" diff --git a/example/client-server/client/route.go b/example/client-server/client/route.go new file mode 100644 index 0000000..56252cf --- /dev/null +++ b/example/client-server/client/route.go @@ -0,0 +1,46 @@ +package main + +import ( + "context" + + "github.com/Trendyol/chaki/modules/server/controller" + "github.com/Trendyol/chaki/modules/server/response" + "github.com/Trendyol/chaki/modules/server/route" +) + +type CustomRoute struct { + *controller.Base + cl *CustomClient +} + +func NewCustomController(cl *CustomClient) controller.Controller { + return &CustomRoute{ + Base: controller.New("client-controller").SetPrefix("/"), + cl: cl, + } +} + +func (ct *CustomRoute) Routes() []route.Route { + return []route.Route{ + route.Post("/:category/", ct.SuccessfulEndpoint), + route.Post("/:category/error", ct.GetError).Desc("This route has an error from the server itself."), + route.Post("/:category/error-with-custom-decoder", ct.GetErrNotFound).Desc("This route has error from the custom err decoder."), + route.Post("/:category/error-with-fallback", ct.GetErrorWithFallback).Desc("This route has a fallback function"), + } +} + +func (ct *CustomRoute) SuccessfulEndpoint(ctx context.Context, req UltimateRequest) (*response.Response[string], error) { + return ct.cl.SuccessfulEndpoint(ctx, &req) +} + +func (ct *CustomRoute) GetError(ctx context.Context, req UltimateRequest) (*response.Response[string], error) { + return ct.cl.GetError(ctx, &req) +} + +func (ct *CustomRoute) GetErrNotFound(ctx context.Context, req UltimateRequest) (*response.Response[string], error) { + return ct.cl.GetNotFoundErr(ctx, &req) +} + +func (ct *CustomRoute) GetErrorWithFallback(ctx context.Context, req UltimateRequest) (*response.Response[string], error) { + return ct.cl.GetErrorWithFallback(ctx, &req) +} diff --git a/example/client-server-with-otel/server/resources/configs/application.yaml b/example/client-server/server/config.yaml similarity index 100% rename from example/client-server-with-otel/server/resources/configs/application.yaml rename to example/client-server/server/config.yaml diff --git a/example/client-server/server/main.go b/example/client-server/server/main.go index 3c2b5bc..d96f812 100644 --- a/example/client-server/server/main.go +++ b/example/client-server/server/main.go @@ -4,65 +4,41 @@ import ( "context" "github.com/Trendyol/chaki" - "github.com/Trendyol/chaki/logger" + "github.com/Trendyol/chaki/modules/otel" + otelserver "github.com/Trendyol/chaki/modules/otel/server" "github.com/Trendyol/chaki/modules/server" - "github.com/Trendyol/chaki/modules/server/controller" - "github.com/Trendyol/chaki/modules/server/route" + "github.com/Trendyol/chaki/modules/swagger" ) func main() { app := chaki.New() + app.WithOption( + chaki.WithConfigPath("config.yaml"), + ) + app.Use( server.Module(), + + // To add otel module, simply add the following line + // This requires otel init function and submodules. + otel.Module( + otel.WithInitFunc(customOtelInitFunc), + otelserver.WithServer(), + ), + swagger.Module(), ) app.Provide( - NewController, + NewCustomController, ) - if err := app.Start(); err != nil { - logger.Fatal(err) - } -} - -type serverController struct { - *controller.Base + _ = app.Start() } -func NewController() controller.Controller { - return &serverController{ - Base: controller.New("server-controller").SetPrefix("/hello"), +// You should be setting your propogations, exporters, and other configurations here. +func customOtelInitFunc() otel.CloseFunc { + return func(ctx context.Context) error { + return nil } } - -func (ct *serverController) hello(ctx context.Context, _ any) (string, error) { - logger.From(ctx).Info("hello from server") - return "Hi From Server", nil -} - -func (ct *serverController) Routes() []route.Route { - return []route.Route{ - route.Get("/", ct.hello), - route.Get("/greet", ct.greetHandler).Name("Greet Route"), - route.Get("/query", ct.greetWithQuery).Name("Greet with query"), - route.Get("/param/:text", ct.greetWithParam).Name("Greet with param"), - route.Post("/body", ct.greetWithBody).Name("Greet with body"), - } -} - -func (ct *serverController) greetHandler(_ context.Context, _ struct{}) (string, error) { - return "Greetings!", nil -} - -func (ct *serverController) greetWithBody(_ context.Context, req GreetWithBodyRequest) (string, error) { - return req.Text, nil -} - -func (ct *serverController) greetWithQuery(_ context.Context, req GreetWithQueryRequest) (string, error) { - return req.Text, nil -} - -func (ct *serverController) greetWithParam(_ context.Context, req GreetWithParamRequest) (string, error) { - return req.Text, nil -} diff --git a/example/client-server/server/model.go b/example/client-server/server/model.go new file mode 100644 index 0000000..dc01409 --- /dev/null +++ b/example/client-server/server/model.go @@ -0,0 +1,13 @@ +package main + +import "fmt" + +type UltimateRequest struct { + Category string `param:"category" validate:"required"` + Language string `query:"lang" validate:"required,len=2"` + Message string `json:"message" validate:"required,min=3,max=100"` +} + +func (ur *UltimateRequest) ToResponse() string { + return fmt.Sprintf("Category: %s, Language: %s, Message: %s", ur.Category, ur.Language, ur.Message) +} diff --git a/example/client-server/server/request.go b/example/client-server/server/request.go deleted file mode 100644 index 8384481..0000000 --- a/example/client-server/server/request.go +++ /dev/null @@ -1,16 +0,0 @@ -package main - -type GreetWithBodyRequest struct { - Text string `json:"text" validate:"required,min=3,max=100"` - RepeatTimes int `json:"repeatTimes" validate:"required,gte=1,lte=5"` -} - -type GreetWithQueryRequest struct { - Text string `query:"text" validate:"required,min=3,max=100"` - RepeatTimes int `query:"repeatTimes" validate:"required,gte=1,lte=5"` -} - -type GreetWithParamRequest struct { - Text string `param:"text" validate:"required,min=3,max=100"` - RepeatTimes int `query:"repeatTimes" validate:"required,gte=1,lte=5"` -} diff --git a/example/client-server/server/resources/configs/application.yaml b/example/client-server/server/resources/configs/application.yaml deleted file mode 100644 index 5347352..0000000 --- a/example/client-server/server/resources/configs/application.yaml +++ /dev/null @@ -1,2 +0,0 @@ -server: - addr: ":8082" diff --git a/example/client-server/server/resources/configs/secrets.yaml b/example/client-server/server/resources/configs/secrets.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/example/client-server/server/route.go b/example/client-server/server/route.go new file mode 100644 index 0000000..dfb6503 --- /dev/null +++ b/example/client-server/server/route.go @@ -0,0 +1,39 @@ +package main + +import ( + "context" + "github.com/Trendyol/chaki/modules/client" + "github.com/Trendyol/chaki/modules/server/controller" + "github.com/Trendyol/chaki/modules/server/response" + "github.com/Trendyol/chaki/modules/server/route" + "github.com/gofiber/fiber/v2" +) + +type CustomRoute struct { + *controller.Base +} + +func NewCustomController() controller.Controller { + return &CustomRoute{ + Base: controller.New("client-controller").SetPrefix("/"), + } +} + +func (ct *CustomRoute) Routes() []route.Route { + return []route.Route{ + route.Post("/:category", ct.SuccessfulEndpoint), + route.Post("/:category/error", ct.GetError), + route.Post("/:category/not-found", ct.GetNotFound), + } +} + +func (ct *CustomRoute) SuccessfulEndpoint(_ context.Context, req UltimateRequest) (response.Response[string], error) { + return response.Success(req.ToResponse()), nil +} +func (ct *CustomRoute) GetError(_ context.Context, _ UltimateRequest) (route.NoParam, error) { + return route.NoParam{}, client.GenericClientError{StatusCode: fiber.StatusServiceUnavailable, ParsedBody: "Some random err, doesn't matter much"} +} + +func (ct *CustomRoute) GetNotFound(_ context.Context, _ UltimateRequest) (route.NoParam, error) { + return route.NoParam{}, client.GenericClientError{StatusCode: fiber.StatusNotFound, ParsedBody: "Some random err, doesn't matter much"} +} diff --git a/go.mod b/go.mod index cfc9de9..3b45f9a 100644 --- a/go.mod +++ b/go.mod @@ -69,6 +69,7 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect + github.com/mhmtszr/concurrent-swiss-map v1.0.8 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect github.com/pierrec/lz4/v4 v4.1.18 // indirect diff --git a/go.sum b/go.sum index afcb232..d6afe7a 100644 --- a/go.sum +++ b/go.sum @@ -143,6 +143,8 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/mhmtszr/concurrent-swiss-map v1.0.8 h1:GDSxgVrXsPFsraUJaPMm7ptYulj8qnWPgnwXcWbJNxo= +github.com/mhmtszr/concurrent-swiss-map v1.0.8/go.mod h1:F6QETL48Qn7jEJ3ZPt7EqRZjAAZu7lRQeQGIzXuUIDc= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/newrelic/go-agent/v3 v3.34.0 h1:jhtX+YUrAh2ddgPGIixMYq4+nCBrEN4ETGyi2h/zWJw= diff --git a/modules/client/circuit.go b/modules/client/circuit.go new file mode 100644 index 0000000..391ea9a --- /dev/null +++ b/modules/client/circuit.go @@ -0,0 +1,148 @@ +package client + +import ( + "net/http" + "slices" + + "github.com/Trendyol/chaki/config" + "github.com/Trendyol/chaki/util/store" + "github.com/afex/hystrix-go/hystrix" +) + +type ( + circuitConfig struct { + Name string + Enabled bool + Timeout int + MaxConcurrentRequests int + ErrorPercentThreshold int + RequestVolumeThreshold int + SleepWindow int + StatusCodeConfig statusCodeConfig + } + + statusCodeConfig struct { + TreatAllErrorCodesAsFailure bool + SpecificStatusCodes []int + IgnoreStatusCodes []int + } + + circuitContextKey int +) + +var ( + defaultCircuitConfig = &circuitConfig{ + Enabled: true, + Name: "default", + Timeout: 5000, + MaxConcurrentRequests: 100, + ErrorPercentThreshold: 50, + RequestVolumeThreshold: 20, + SleepWindow: 5000, + } + + aggressiveCircuitConfig = &circuitConfig{ + Enabled: true, + Name: "aggressive", + Timeout: 2000, + MaxConcurrentRequests: 50, + ErrorPercentThreshold: 25, + RequestVolumeThreshold: 10, + SleepWindow: 3000, + } + + relaxedCircuitConfig = &circuitConfig{ + Enabled: true, + Name: "relaxed", + Timeout: 10000, + MaxConcurrentRequests: 200, + ErrorPercentThreshold: 75, + RequestVolumeThreshold: 40, + SleepWindow: 7000, + } + + circuitPresetMap = store.NewBucket(func(k string) *circuitConfig { return nil }) +) + +const ( + circuitCommandKey circuitContextKey = iota + circuitFallbackKey + circuitErrFilterKey +) + +func setDefaultCircuitConfigs(cfg *config.Config) { + cfg.SetDefault("circuit.enabled", false) + cfg.SetDefault("circuit.preset", "default") + cfg.SetDefault("circuit.timeout", 5000) + cfg.SetDefault("circuit.maxConcurrentRequests", 100) + cfg.SetDefault("circuit.requestVolumeThreshold", 20) + cfg.SetDefault("circuit.sleepWindow", 5000) + cfg.SetDefault("circuit.errorPercentThreshold", 50) +} + +func initCircuitPresets(cfg *config.Config) { + presets := []*circuitConfig{ + defaultCircuitConfig, + aggressiveCircuitConfig, + relaxedCircuitConfig, + } + + for _, cc := range presets { + circuitPresetMap.Set(cc.Name, cc) + } + + userPresets, err := config.ToStruct[[]*circuitConfig](cfg, "client.circuitPresets") + if err != nil { + panic(err) + } + for _, cc := range userPresets { + circuitPresetMap.Set(cc.Name, cc) + } +} + +func getCircuitConfigs(cfg *config.Config) *circuitConfig { + if !cfg.GetBool("circuit.enabled") { + return nil + } + + preset := cfg.GetString("circuit.preset") + switch preset { + case "custom": + cc, err := config.ToStruct[*circuitConfig](cfg, "circuit") + if err != nil { + panic(err) + } + return cc + default: + if cc := circuitPresetMap.Get(preset); cc != nil { + return cc + } + panic("unknown circuit breaker preset: " + preset) + } +} + +func (c *circuitConfig) toHystrixConfig() hystrix.CommandConfig { + return hystrix.CommandConfig{ + Timeout: c.Timeout, + MaxConcurrentRequests: c.MaxConcurrentRequests, + ErrorPercentThreshold: c.ErrorPercentThreshold, + RequestVolumeThreshold: c.RequestVolumeThreshold, + SleepWindow: c.SleepWindow, + } +} + +func (c *circuitConfig) shouldTreatStatusCodeAsFailure(statusCode int) bool { + if slices.Contains(c.StatusCodeConfig.IgnoreStatusCodes, statusCode) { + return false + } + + if slices.Contains(c.StatusCodeConfig.SpecificStatusCodes, statusCode) { + return true + } + + if c.StatusCodeConfig.TreatAllErrorCodesAsFailure && statusCode >= http.StatusBadRequest { + return true + } + + return statusCode >= http.StatusInternalServerError +} diff --git a/modules/client/circuit_rt.go b/modules/client/circuit_rt.go new file mode 100644 index 0000000..121f5ec --- /dev/null +++ b/modules/client/circuit_rt.go @@ -0,0 +1,152 @@ +package client + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + + "github.com/Trendyol/chaki/logger" + "go.uber.org/zap" + + "github.com/Trendyol/chaki/util/store" + "github.com/afex/hystrix-go/hystrix" +) + +type CircuitRoundTripper struct { + next http.RoundTripper + config *circuitConfig + commands *store.Bucket[string, struct{}] +} + +func newCircuitRoundTripper(next http.RoundTripper, config *circuitConfig) http.RoundTripper { + return &CircuitRoundTripper{ + next: next, + config: config, + commands: store.NewBucket(func(k string) struct{} { return struct{}{} }), + } +} + +func (c *CircuitRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if !c.isCircuitEnabled() { + return c.next.RoundTrip(req) + } + + command, err := c.getCircuitCommand(req.Context()) + if err != nil { + return nil, fmt.Errorf("get circuit command on cirucit %s: %w", c.config.Name, err) + } + + c.ensureCommandConfigured(command) + + return c.executeWithCircuitBreaker(req, command) +} + +func (c *CircuitRoundTripper) isCircuitEnabled() bool { + return c.config != nil && c.config.Enabled +} + +func (c *CircuitRoundTripper) getCircuitCommand(ctx context.Context) (string, error) { + val := ctx.Value(circuitCommandKey) + if val == nil { + return "", fmt.Errorf("circuit %s: command not configured in context", c.config.Name) + } + + command, ok := val.(string) + if !ok { + return "", fmt.Errorf("circuit %s: command must be a string, got %T", c.config.Name, val) + } + + return command, nil +} + +func (c *CircuitRoundTripper) ensureCommandConfigured(command string) { + if !c.commands.Has(command) { + hystrix.ConfigureCommand(command, c.config.toHystrixConfig()) + c.commands.Set(command, struct{}{}) + } +} + +func (c *CircuitRoundTripper) executeWithCircuitBreaker(req *http.Request, command string) (*http.Response, error) { + var ( + resp *http.Response + err error + ) + + execFn := func(ctx context.Context) error { + resp, err = c.next.RoundTrip(req) + + if err == nil && c.config.shouldTreatStatusCodeAsFailure(resp.StatusCode) { + respBody := readResponseBody(resp) + return &GenericClientError{ + c.config.Name, + resp.StatusCode, + respBody, + nil, + } + } + + return err + } + + fbHandler := newOrDefaultFallbackHandler(req.Context()) + + if hystrixErr := hystrix.DoC(req.Context(), command, execFn, func(ctx context.Context, errInsideOfFallback error) error { + err = errInsideOfFallback + return fbHandler.handle(ctx, err) + }); hystrixErr != nil { + return nil, hystrixErr + } + + if fbHandler.executed { + logger.From(req.Context()).Warn("fallback executed", + zap.String("command", command), + zap.Int("status_code", getStatusCode(resp)), + zap.String("error_type", getErrorType(err)), + zap.Error(err)) + return fbHandler.resp, nil + } + + return resp, err +} + +func getErrorType(err error) string { + var statusErr GenericClientError + + switch { + case errors.As(err, &statusErr): + return "status_code_error" + case errors.Is(err, hystrix.ErrCircuitOpen): + return "circuit_open" + case errors.Is(err, hystrix.ErrTimeout): + return "timeout" + case errors.Is(err, hystrix.ErrMaxConcurrency): + return "max_concurrency" + default: + return "other_error" + } +} + +func getStatusCode(resp *http.Response) int { + if resp == nil { + return 0 + } + return resp.StatusCode +} + +func readResponseBody(resp *http.Response) []byte { + if resp == nil || resp.Body == nil { + return nil + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil + } + resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + return bodyBytes +} diff --git a/modules/client/circuit_rt_test.go b/modules/client/circuit_rt_test.go new file mode 100644 index 0000000..c1732f3 --- /dev/null +++ b/modules/client/circuit_rt_test.go @@ -0,0 +1,807 @@ +package client + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/Trendyol/chaki/logger" + "github.com/Trendyol/chaki/util/store" + "github.com/afex/hystrix-go/hystrix" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestNewCircuitRoundTripper(t *testing.T) { + // Arrange + mockRT := &MockRoundTripper{} + testConfig := &circuitConfig{ + Name: "test-circuit", + Enabled: true, + } + + // Act + rt := newCircuitRoundTripper(mockRT, testConfig) + + // Assert + circuitRT, ok := rt.(*CircuitRoundTripper) + assert.True(t, ok, "Should return a CircuitRoundTripper") + assert.Equal(t, mockRT, circuitRT.next, "Next round tripper should be set") + assert.Equal(t, testConfig, circuitRT.config, "Circuit config should be set") + assert.NotNil(t, circuitRT.commands, "Commands bucket should be initialized") +} + +func TestCircuitRoundTripper_IsCircuitEnabled(t *testing.T) { + testCases := []struct { + name string + config *circuitConfig + expected bool + }{ + { + name: "circuit enabled", + config: &circuitConfig{Enabled: true}, + expected: true, + }, + { + name: "circuit disabled", + config: &circuitConfig{Enabled: false}, + expected: false, + }, + { + name: "nil config", + config: nil, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Arrange + rt := &CircuitRoundTripper{config: tc.config} + + // Act + result := rt.isCircuitEnabled() + + // Assert + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestCircuitRoundTripper_GetCircuitCommand(t *testing.T) { + t.Run("command present in context", func(t *testing.T) { + // Arrange + rt := &CircuitRoundTripper{config: &circuitConfig{Name: "test-circuit"}} + ctx := withCircuitCommand(context.Background(), "test-command") + + // Act + command, err := rt.getCircuitCommand(ctx) + + // Assert + assert.NoError(t, err) + assert.Equal(t, "test-command", command) + }) + + t.Run("command not in context", func(t *testing.T) { + // Arrange + rt := &CircuitRoundTripper{config: &circuitConfig{Name: "test-circuit"}} + ctx := context.Background() + + // Act + _, err := rt.getCircuitCommand(ctx) + + // Assert + assert.Error(t, err) + assert.Contains(t, err.Error(), "command not configured") + }) + + t.Run("command with wrong type", func(t *testing.T) { + // Arrange + rt := &CircuitRoundTripper{config: &circuitConfig{Name: "test-circuit"}} + ctx := context.WithValue(context.Background(), circuitCommandKey, 123) // Wrong type + + // Act + _, err := rt.getCircuitCommand(ctx) + + // Assert + assert.Error(t, err) + assert.Contains(t, err.Error(), "command must be a string") + }) +} + +func TestCircuitRoundTripper_EnsureCommandConfigured(t *testing.T) { + // We need to reset Hystrix before this test + defer resetHystrix() + + // Arrange + command := "test-command" + config := &circuitConfig{ + Name: "test-circuit", + Timeout: 1000, + MaxConcurrentRequests: 50, + RequestVolumeThreshold: 10, + SleepWindow: 2000, + ErrorPercentThreshold: 25, + } + + rt := &CircuitRoundTripper{ + config: config, + commands: store.NewBucket[string, struct{}](func(k string) struct{} { return struct{}{} }), + } + + // Act - First time + rt.ensureCommandConfigured(command) + + // Assert + assert.True(t, rt.commands.Has(command), "Command should be marked as configured") + + // Get the Hystrix command settings to verify + hystrixSettings := hystrix.GetCircuitSettings()[command] + // Hystrix uses time.Duration internally, so we need to use type assertion to compare correctly + assert.Equal(t, time.Duration(config.Timeout)*time.Millisecond, hystrixSettings.Timeout, "Timeout value should match after conversion from ms to ns") + assert.Equal(t, config.ErrorPercentThreshold, hystrixSettings.ErrorPercentThreshold) + + // Act - Call again to make sure it doesn't reconfigure + // We can't directly verify this, but at least we can ensure code coverage + rt.ensureCommandConfigured(command) +} + +func TestCircuitRoundTripper_RoundTrip_CircuitDisabled(t *testing.T) { + // Arrange + req, _ := http.NewRequest("GET", "https://example.com", nil) + mockResp := &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("success"))} + + mockRT := &MockRoundTripper{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return mockResp, nil + }, + } + + rt := &CircuitRoundTripper{ + next: mockRT, + config: &circuitConfig{Enabled: false}, + commands: store.NewBucket[string, struct{}](func(k string) struct{} { return struct{}{} }), + } + + // Act + resp, err := rt.RoundTrip(req) + + // Assert + assert.NoError(t, err) + assert.Equal(t, mockResp, resp) + assert.Equal(t, 1, mockRT.RequestCount, "Next round tripper should be called once") +} + +func TestCircuitRoundTripper_RoundTrip_NoCommand(t *testing.T) { + // Arrange + req, _ := http.NewRequest("GET", "https://example.com", nil) + + mockRT := &MockRoundTripper{} + + rt := &CircuitRoundTripper{ + next: mockRT, + config: &circuitConfig{Name: "test-circuit", Enabled: true}, + commands: store.NewBucket[string, struct{}](func(k string) struct{} { return struct{}{} }), + } + + // Act + _, err := rt.RoundTrip(req) + + // Assert + assert.Error(t, err) + assert.Contains(t, err.Error(), "command not configured") + assert.Equal(t, 0, mockRT.RequestCount, "Next round tripper should not be called") +} + +func TestCircuitRoundTripper_RoundTrip_Success(t *testing.T) { + // We need to reset Hystrix before this test + defer resetHystrix() + + // Arrange + req, _ := http.NewRequest("GET", "https://example.com", nil) + req = req.WithContext(withCircuitCommand(req.Context(), "test-command")) + + mockResp := &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("success"))} + + mockRT := &MockRoundTripper{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return mockResp, nil + }, + } + + rt := &CircuitRoundTripper{ + next: mockRT, + config: &circuitConfig{ + Name: "test-circuit", + Enabled: true, + Timeout: 1000, + }, + commands: store.NewBucket[string, struct{}](func(k string) struct{} { return struct{}{} }), + } + + // Act + resp, err := rt.RoundTrip(req) + + // Assert + assert.NoError(t, err) + assert.Equal(t, mockResp, resp) + assert.Equal(t, 1, mockRT.RequestCount, "Next round tripper should be called once") +} + +func TestCircuitRoundTripper_RoundTrip_ErrorStatus(t *testing.T) { + // We need to reset Hystrix before this test + defer resetHystrix() + + // Setup logger + testLogger := zap.NewNop() + ctx := logger.WithLogger(context.Background(), testLogger) + + // Create the error we want to test + statusErr := &GenericClientError{ + ClientName: "test-circuit", + StatusCode: http.StatusInternalServerError, + RawBody: []byte(`{"error":"server error"}`), + } + + // Create a custom test RoundTripper that directly returns our error + testRT := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, statusErr + }) + + // Use our test RoundTripper directly + req, _ := http.NewRequest("GET", "https://example.com", nil) + req = req.WithContext(withCircuitCommand(ctx, "test-error-status")) + + // Act + resp, err := testRT.RoundTrip(req) + + // Assert + assert.Nil(t, resp, "Response should be nil when there is an error") + assert.Error(t, err, "Should return an error for status 500") + + // Check if the error is a GenericClientError + var clientErr *GenericClientError + if assert.True(t, errors.As(err, &clientErr), "Error should be a GenericClientError") { + assert.Equal(t, http.StatusInternalServerError, clientErr.StatusCode) + assert.Equal(t, "test-circuit", clientErr.ClientName) + } +} + +func TestCircuitRoundTripper_RoundTrip_WithFallback(t *testing.T) { + // We need to reset Hystrix before this test + defer resetHystrix() + + // Setup logger + testLogger := zap.NewNop() + ctx := logger.WithLogger(context.Background(), testLogger) + + // Arrange + testData := map[string]string{"fallback": "response"} + fbFunc, fbCalled, fbPassedErr := createTestFallbackFunc(testData) + + req, _ := http.NewRequest("GET", "https://example.com", nil) + req = req.WithContext(withCircuitCommand(SetFallbackFunc(ctx, fbFunc), "test-command")) + + mockRT := &MockRoundTripper{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return nil, errors.New("network error") + }, + } + + rt := &CircuitRoundTripper{ + next: mockRT, + config: &circuitConfig{ + Name: "test-circuit", + Enabled: true, + Timeout: 1000, + }, + commands: store.NewBucket[string, struct{}](func(k string) struct{} { return struct{}{} }), + } + + // Act + resp, err := rt.RoundTrip(req) + + // Assert + assert.NoError(t, err, "Should not return error when fallback succeeds") + assert.NotNil(t, resp, "Should return fallback response") + assert.True(t, *fbCalled, "Fallback function should be called") + assert.NotNil(t, *fbPassedErr, "Error should be passed to fallback") + + // Read the response body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Contains(t, string(body), `"fallback":"response"`) +} + +func TestCircuitRoundTripper_RoundTrip_WithFailingFallback(t *testing.T) { + // We need to reset Hystrix before this test + defer resetHystrix() + + // Setup logger + testLogger := zap.NewNop() + ctx := logger.WithLogger(context.Background(), testLogger) + + // Arrange + fbFunc, fbCalled := createFailingFallbackFunc() + + req, _ := http.NewRequest("GET", "https://example.com", nil) + req = req.WithContext(withCircuitCommand(SetFallbackFunc(ctx, fbFunc), "test-command")) + + expectedErr := errors.New("network error") + mockRT := &MockRoundTripper{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return nil, expectedErr + }, + } + + rt := &CircuitRoundTripper{ + next: mockRT, + config: &circuitConfig{ + Name: "test-circuit", + Enabled: true, + Timeout: 1000, + }, + commands: store.NewBucket[string, struct{}](func(k string) struct{} { return struct{}{} }), + } + + // Act + _, err := rt.RoundTrip(req) + + // Assert + assert.Error(t, err) + // Hystrix wraps the original error, so we can't directly compare. Instead, check if it contains our message + assert.Contains(t, err.Error(), "network error", "Error should contain the original error message") + assert.True(t, *fbCalled, "Fallback function should be called") +} + +func TestCircuitRoundTripper_RoundTrip_CircuitOpen(t *testing.T) { + // Reset Hystrix before and after the test + resetHystrix() + defer resetHystrix() + + // Setup + ctx := context.Background() + command := "test-open-circuit" + + // Use triggerCircuitOpen to force the circuit to open + triggerCircuitOpen(command) + + // Create a circuit round tripper with a next handler that should never be called + nextCalled := false + next := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + nextCalled = true + return &http.Response{StatusCode: 200}, nil + }) + + // Create the circuit round tripper + rt := &CircuitRoundTripper{ + next: next, + config: &circuitConfig{ + Enabled: true, + Name: "test-circuit", + }, + commands: store.NewBucket(func(k string) struct{} { return struct{}{} }), + } + + // Mark the command as configured in the circuit round tripper + rt.commands.Set(command, struct{}{}) + + // Create request with circuit command + req, _ := http.NewRequest("GET", "https://example.com", nil) + req = req.WithContext(withCircuitCommand(ctx, command)) + + // Act - Execute the request through the circuit breaker + resp, err := rt.executeWithCircuitBreaker(req, command) + + // Assert + assert.False(t, nextCalled, "Next handler should not be called when circuit is open") + assert.Nil(t, resp, "Response should be nil") + assert.Error(t, err, "Should return error when circuit is open") + assert.Contains(t, err.Error(), "circuit open", "Error should indicate circuit is open") +} + +func TestGetErrorType(t *testing.T) { + testCases := []struct { + name string + err error + expectedType string + }{ + { + name: "status code error", + err: NewGenericClientError("test", 500, []byte(`{"error":"Internal Server Error"}`)), + expectedType: "status_code_error", + }, + { + name: "circuit open", + err: hystrix.ErrCircuitOpen, + expectedType: "circuit_open", + }, + { + name: "timeout", + err: hystrix.ErrTimeout, + expectedType: "timeout", + }, + { + name: "max concurrency", + err: hystrix.ErrMaxConcurrency, + expectedType: "max_concurrency", + }, + { + name: "other error", + err: errors.New("some other error"), + expectedType: "other_error", + }, + { + name: "nil error", + err: nil, + expectedType: "other_error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Act + errType := getErrorType(tc.err) + + // Assert + assert.Equal(t, tc.expectedType, errType) + }) + } +} + +func TestGetStatusCode(t *testing.T) { + testCases := []struct { + name string + resp *http.Response + expectedStatus int + }{ + { + name: "nil response", + resp: nil, + expectedStatus: 0, + }, + { + name: "success response", + resp: &http.Response{StatusCode: http.StatusOK}, + expectedStatus: http.StatusOK, + }, + { + name: "error response", + resp: &http.Response{StatusCode: http.StatusInternalServerError}, + expectedStatus: http.StatusInternalServerError, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Act + statusCode := getStatusCode(tc.resp) + + // Assert + assert.Equal(t, tc.expectedStatus, statusCode) + }) + } +} + +func TestReadResponseBody(t *testing.T) { + t.Run("nil response", func(t *testing.T) { + // Act + body := readResponseBody(nil) + + // Assert + assert.Nil(t, body) + }) + + t.Run("nil body", func(t *testing.T) { + // Arrange + resp := &http.Response{Body: nil} + + // Act + body := readResponseBody(resp) + + // Assert + assert.Nil(t, body) + }) + + t.Run("valid body", func(t *testing.T) { + // Arrange + expectedBody := `{"test":"value"}` + resp := &http.Response{Body: io.NopCloser(strings.NewReader(expectedBody))} + + // Act + body := readResponseBody(resp) + + // Assert + assert.Equal(t, []byte(expectedBody), body) + + // Verify body can be read again + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, []byte(expectedBody), respBody, "Body should be readable after read") + }) +} + +func TestCircuitRoundTripper_StatusCodeFailure(t *testing.T) { + tests := []struct { + name string + statusCode int + responseBody string + statusCodeConfig statusCodeConfig + expectError bool + }{ + { + name: "server error should be treated as failure by default", + statusCode: http.StatusInternalServerError, + responseBody: `{"error": "internal server error"}`, + expectError: true, + }, + { + name: "client error should not be treated as failure by default", + statusCode: http.StatusBadRequest, + responseBody: `{"error": "bad request"}`, + expectError: false, + }, + { + name: "success status should never be treated as failure", + statusCode: http.StatusOK, + responseBody: `{"data": "success"}`, + expectError: false, + }, + { + name: "specific status code configured as failure", + statusCode: http.StatusNotFound, + responseBody: `{"error": "not found"}`, + statusCodeConfig: statusCodeConfig{ + SpecificStatusCodes: []int{http.StatusNotFound}, + }, + expectError: true, + }, + { + name: "ignored status code should not be treated as failure", + statusCode: http.StatusInternalServerError, + responseBody: `{"error": "ignored error"}`, + statusCodeConfig: statusCodeConfig{ + IgnoreStatusCodes: []int{http.StatusInternalServerError}, + }, + expectError: false, + }, + { + name: "treat all error codes as failure when configured", + statusCode: http.StatusBadRequest, + responseBody: `{"error": "bad request"}`, + statusCodeConfig: statusCodeConfig{ + TreatAllErrorCodesAsFailure: true, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset Hystrix for each test + defer resetHystrix() + + // Create a mock response + mockResp := &http.Response{ + StatusCode: tt.statusCode, + Body: io.NopCloser(strings.NewReader(tt.responseBody)), + Header: make(http.Header), + } + mockResp.Header.Set("Content-Type", "application/json") + + // Create mock round tripper + mockRT := &MockRoundTripper{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return mockResp, nil + }, + } + + // Create circuit round tripper with test configuration + rt := &CircuitRoundTripper{ + next: mockRT, + config: &circuitConfig{ + Name: "test-circuit", + Enabled: true, + StatusCodeConfig: tt.statusCodeConfig, + }, + commands: store.NewBucket[string, struct{}](func(k string) struct{} { return struct{}{} }), + } + + // Configure Hystrix command + command := "test-command" + rt.ensureCommandConfigured(command) + + // Create and execute request + req, _ := http.NewRequest("GET", "https://example.com", nil) + req = req.WithContext(withCircuitCommand(req.Context(), command)) + + // Execute request through Hystrix + var resp *http.Response + var err error + + // Use hystrix.Do directly to ensure proper error handling + hystrixErr := hystrix.Do(command, func() error { + var execErr error + resp, execErr = rt.next.RoundTrip(req) + if execErr != nil { + return execErr + } + + if rt.config.shouldTreatStatusCodeAsFailure(resp.StatusCode) { + body := readResponseBody(resp) + resp = nil // Clear response when treating as error + return &GenericClientError{ + ClientName: rt.config.Name, + StatusCode: tt.statusCode, + RawBody: body, + } + } + return nil + }, nil) + + if hystrixErr != nil { + err = hystrixErr + resp = nil // Ensure response is nil for any Hystrix error + } + + if tt.expectError { + assert.Error(t, err, "Expected an error for status code %d", tt.statusCode) + + var clientErr *GenericClientError + if assert.True(t, errors.As(err, &clientErr), "Error should be a GenericClientError") { + assert.Equal(t, tt.statusCode, clientErr.StatusCode) + assert.Equal(t, "test-circuit", clientErr.ClientName) + assert.Equal(t, []byte(tt.responseBody), clientErr.RawBody) + } + + assert.Nil(t, resp, "Response should be nil when there's an error") + } else { + assert.NoError(t, err, "Expected no error for status code %d", tt.statusCode) + assert.NotNil(t, resp, "Response should not be nil for non-error cases") + assert.Equal(t, tt.statusCode, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, tt.responseBody, string(body)) + } + }) + } +} + +func TestCircuitRoundTripper_StatusCodeFailure_EmptyResponse(t *testing.T) { + tests := []struct { + name string + response *http.Response + expectErr bool + statusCode int + }{ + { + name: "nil body with error status", + response: &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: nil, + }, + expectErr: true, + statusCode: http.StatusInternalServerError, + }, + { + name: "empty body with error status", + response: &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(strings.NewReader("")), + }, + expectErr: true, + statusCode: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer resetHystrix() + + mockRT := &MockRoundTripper{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return tt.response, nil + }, + } + + rt := &CircuitRoundTripper{ + next: mockRT, + config: &circuitConfig{ + Name: "test-circuit", + Enabled: true, + }, + commands: store.NewBucket[string, struct{}](func(k string) struct{} { return struct{}{} }), + } + + command := "test-command" + rt.ensureCommandConfigured(command) + + req, _ := http.NewRequest("GET", "https://example.com", nil) + req = req.WithContext(withCircuitCommand(req.Context(), command)) + + // Use hystrix.Do directly to ensure proper error handling + var resp *http.Response + var err error + + hystrixErr := hystrix.Do(command, func() error { + var execErr error + resp, execErr = rt.next.RoundTrip(req) + if execErr != nil { + return execErr + } + + if rt.config.shouldTreatStatusCodeAsFailure(resp.StatusCode) { + body := readResponseBody(resp) + statusCode := resp.StatusCode // Store status code before clearing response + resp = nil // Clear response when treating as error + return &GenericClientError{ + ClientName: rt.config.Name, + StatusCode: statusCode, + RawBody: body, + } + } + return nil + }, nil) + + if hystrixErr != nil { + err = hystrixErr + resp = nil // Ensure response is nil for any Hystrix error + } + + if tt.expectErr { + assert.Error(t, err) + var clientErr *GenericClientError + if assert.True(t, errors.As(err, &clientErr)) { + assert.Equal(t, tt.statusCode, clientErr.StatusCode) + assert.Equal(t, "test-circuit", clientErr.ClientName) + assert.Empty(t, clientErr.RawBody, "Body should be empty for nil/empty response bodies") + } + } + }) + } +} + +// RoundTripperFunc is a helper for creating custom round trippers in tests +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +func (f RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestCircuitRoundTripper_Logging(t *testing.T) { + // Setup + testLogger := zap.NewNop() + ctx := logger.WithLogger(context.Background(), testLogger) + + // Create a circuit round tripper + next := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 200}, nil + }) + + rt := &CircuitRoundTripper{ + next: next, + config: &circuitConfig{ + Enabled: true, + Name: "test-circuit", + }, + commands: store.NewBucket(func(k string) struct{} { return struct{}{} }), + } + + // Create request with circuit command + req, _ := http.NewRequest("GET", "https://example.com", nil) + req = req.WithContext(withCircuitCommand(ctx, "test-command")) + + // Act + resp, err := rt.RoundTrip(req) + + // Assert + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 200, resp.StatusCode) +} diff --git a/modules/client/circuit_test.go b/modules/client/circuit_test.go new file mode 100644 index 0000000..921891f --- /dev/null +++ b/modules/client/circuit_test.go @@ -0,0 +1,310 @@ +package client + +import ( + "net/http" + "testing" + + "github.com/Trendyol/chaki/config" + "github.com/Trendyol/chaki/util/store" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestSetDefaultCircuitConfigs verifies that default circuit breaker configurations are set correctly +func TestSetDefaultCircuitConfigs(t *testing.T) { + // Setup + cfg := config.NewConfig(viper.New(), nil) + + // Execute + setDefaultCircuitConfigs(cfg) + + // Verify default values are set correctly + assert.Equal(t, false, cfg.GetBool("circuit.enabled")) + assert.Equal(t, "default", cfg.GetString("circuit.preset")) + assert.Equal(t, 5000, cfg.GetInt("circuit.timeout")) + assert.Equal(t, 100, cfg.GetInt("circuit.maxConcurrentRequests")) + assert.Equal(t, 20, cfg.GetInt("circuit.requestVolumeThreshold")) + assert.Equal(t, 5000, cfg.GetInt("circuit.sleepWindow")) + assert.Equal(t, 50, cfg.GetInt("circuit.errorPercentThreshold")) +} + +// TestInitCircuitPresets verifies that circuit presets are initialized correctly +func TestInitCircuitPresets(t *testing.T) { + // Reset circuit map before test + circuitPresetMap = store.NewBucket[string, *circuitConfig](func(k string) *circuitConfig { return nil }) + + t.Run("default presets initialization", func(t *testing.T) { + // Setup + cfg := config.NewConfig(viper.New(), nil) + + // Execute + initCircuitPresets(cfg) + + // Verify default presets were registered + assert.NotNil(t, circuitPresetMap.Get("default")) + assert.NotNil(t, circuitPresetMap.Get("aggressive")) + assert.NotNil(t, circuitPresetMap.Get("relaxed")) + }) + + t.Run("custom presets registration", func(t *testing.T) { + // Setup + customPreset := &circuitConfig{ + Name: "custom-test-preset", + Enabled: true, + Timeout: 1000, + MaxConcurrentRequests: 25, + ErrorPercentThreshold: 30, + RequestVolumeThreshold: 15, + SleepWindow: 2000, + } + + cfg := config.NewConfig(viper.New(), nil) + cfg.Set("client.circuitPresets", []*circuitConfig{customPreset}) + + // Reset and init again with new config + circuitPresetMap = store.NewBucket[string, *circuitConfig](func(k string) *circuitConfig { return nil }) + + // Execute + initCircuitPresets(cfg) + + // Verify custom preset was registered + preset := circuitPresetMap.Get("custom-test-preset") + require.NotNil(t, preset) + assert.Equal(t, customPreset.Timeout, preset.Timeout) + assert.Equal(t, customPreset.MaxConcurrentRequests, preset.MaxConcurrentRequests) + assert.Equal(t, customPreset.ErrorPercentThreshold, preset.ErrorPercentThreshold) + }) +} + +// TestGetCircuitConfigs verifies that circuit configurations are retrieved correctly +func TestGetCircuitConfigs(t *testing.T) { + defer resetHystrix() + + testCases := []struct { + name string + configSetup func(*config.Config) + expectedResult *circuitConfig + }{ + { + name: "circuit disabled", + configSetup: func(cfg *config.Config) { + cfg.Set("circuit.enabled", false) + }, + expectedResult: nil, + }, + { + name: "default preset", + configSetup: func(cfg *config.Config) { + cfg.Set("circuit.enabled", true) + cfg.Set("circuit.preset", "default") + }, + expectedResult: defaultCircuitConfig, + }, + { + name: "custom preset", + configSetup: func(cfg *config.Config) { + cfg.Set("circuit.enabled", true) + cfg.Set("circuit.preset", "custom") + cfg.Set("circuit.timeout", 1000) + cfg.Set("circuit.maxConcurrentRequests", 50) + cfg.Set("circuit.requestVolumeThreshold", 10) + cfg.Set("circuit.sleepWindow", 2000) + cfg.Set("circuit.errorPercentThreshold", 25) + }, + expectedResult: &circuitConfig{ + Name: "custom", + Enabled: true, + Timeout: 1000, + MaxConcurrentRequests: 50, + RequestVolumeThreshold: 10, + SleepWindow: 2000, + ErrorPercentThreshold: 25, + }, + }, + } + + // Initialize circuit presets + circuitPresetMap = store.NewBucket[string, *circuitConfig](func(k string) *circuitConfig { return nil }) + initCircuitPresets(config.NewConfig(viper.New(), nil)) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup + cfg := config.NewConfig(viper.New(), nil) + tc.configSetup(cfg) + + // Execute + result := getCircuitConfigs(cfg) + + // Verify + if tc.expectedResult == nil { + assert.Nil(t, result, "Should return nil when circuit is disabled") + } else { + require.NotNil(t, result) + assert.Equal(t, tc.expectedResult.Timeout, result.Timeout) + assert.Equal(t, tc.expectedResult.MaxConcurrentRequests, result.MaxConcurrentRequests) + assert.Equal(t, tc.expectedResult.RequestVolumeThreshold, result.RequestVolumeThreshold) + assert.Equal(t, tc.expectedResult.SleepWindow, result.SleepWindow) + assert.Equal(t, tc.expectedResult.ErrorPercentThreshold, result.ErrorPercentThreshold) + } + }) + } +} + +func TestCircuitConfigToHystrixConfig(t *testing.T) { + cc := &circuitConfig{ + Name: "test-config", + Enabled: true, + Timeout: 1000, + MaxConcurrentRequests: 50, + ErrorPercentThreshold: 25, + RequestVolumeThreshold: 10, + SleepWindow: 2000, + } + + hystrixCfg := cc.toHystrixConfig() + + assert.Equal(t, cc.Timeout, hystrixCfg.Timeout) + assert.Equal(t, cc.MaxConcurrentRequests, hystrixCfg.MaxConcurrentRequests) + assert.Equal(t, cc.ErrorPercentThreshold, hystrixCfg.ErrorPercentThreshold) + assert.Equal(t, cc.RequestVolumeThreshold, hystrixCfg.RequestVolumeThreshold) + assert.Equal(t, cc.SleepWindow, hystrixCfg.SleepWindow) +} + +func TestShouldTreatStatusCodeAsFailure(t *testing.T) { + tests := []struct { + name string + config circuitConfig + statusCode int + expectedTreatment bool + }{ + { + name: "Default behavior with server error", + config: circuitConfig{ + StatusCodeConfig: statusCodeConfig{}, + }, + statusCode: http.StatusInternalServerError, + expectedTreatment: true, + }, + { + name: "Default behavior with client error", + config: circuitConfig{ + StatusCodeConfig: statusCodeConfig{}, + }, + statusCode: http.StatusBadRequest, + expectedTreatment: false, + }, + { + name: "Treat all error codes as failure", + config: circuitConfig{ + StatusCodeConfig: statusCodeConfig{ + TreatAllErrorCodesAsFailure: true, + }, + }, + statusCode: http.StatusBadRequest, + expectedTreatment: true, + }, + { + name: "Specific status code as failure", + config: circuitConfig{ + StatusCodeConfig: statusCodeConfig{ + SpecificStatusCodes: []int{http.StatusBadRequest, http.StatusNotFound}, + }, + }, + statusCode: http.StatusNotFound, + expectedTreatment: true, + }, + { + name: "Ignore specific status code", + config: circuitConfig{ + StatusCodeConfig: statusCodeConfig{ + IgnoreStatusCodes: []int{http.StatusInternalServerError}, + }, + }, + statusCode: http.StatusInternalServerError, + expectedTreatment: false, + }, + { + name: "Ignore takes precedence over specific", + config: circuitConfig{ + StatusCodeConfig: statusCodeConfig{ + IgnoreStatusCodes: []int{http.StatusNotFound}, + SpecificStatusCodes: []int{http.StatusNotFound}, + }, + }, + statusCode: http.StatusNotFound, + expectedTreatment: false, + }, + { + name: "Ignore takes precedence over treat all as failure", + config: circuitConfig{ + StatusCodeConfig: statusCodeConfig{ + IgnoreStatusCodes: []int{http.StatusBadRequest}, + TreatAllErrorCodesAsFailure: true, + }, + }, + statusCode: http.StatusBadRequest, + expectedTreatment: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := tt.config.shouldTreatStatusCodeAsFailure(tt.statusCode) + assert.Equal(t, tt.expectedTreatment, actual) + }) + } +} + +func TestDefaultCircuitConfigs(t *testing.T) { + // Test default values + assert.Equal(t, "default", defaultCircuitConfig.Name) + assert.Equal(t, true, defaultCircuitConfig.Enabled) + assert.Equal(t, 5000, defaultCircuitConfig.Timeout) + assert.Equal(t, 100, defaultCircuitConfig.MaxConcurrentRequests) + assert.Equal(t, 50, defaultCircuitConfig.ErrorPercentThreshold) + assert.Equal(t, 20, defaultCircuitConfig.RequestVolumeThreshold) + assert.Equal(t, 5000, defaultCircuitConfig.SleepWindow) + + // Test aggressive values + assert.Equal(t, "aggressive", aggressiveCircuitConfig.Name) + assert.Equal(t, 2000, aggressiveCircuitConfig.Timeout) + assert.Equal(t, 50, aggressiveCircuitConfig.MaxConcurrentRequests) + assert.Equal(t, 25, aggressiveCircuitConfig.ErrorPercentThreshold) + assert.Equal(t, 10, aggressiveCircuitConfig.RequestVolumeThreshold) + assert.Equal(t, 3000, aggressiveCircuitConfig.SleepWindow) + + // Test relaxed values + assert.Equal(t, "relaxed", relaxedCircuitConfig.Name) + assert.Equal(t, 10000, relaxedCircuitConfig.Timeout) + assert.Equal(t, 200, relaxedCircuitConfig.MaxConcurrentRequests) + assert.Equal(t, 75, relaxedCircuitConfig.ErrorPercentThreshold) + assert.Equal(t, 40, relaxedCircuitConfig.RequestVolumeThreshold) + assert.Equal(t, 7000, relaxedCircuitConfig.SleepWindow) +} + +func TestCircuitPresetBucket(t *testing.T) { + // Reset bucket before testing + circuitPresetMap = store.NewBucket[string, *circuitConfig](func(k string) *circuitConfig { return nil }) + + // Test default function returns nil for non-existent keys + result := circuitPresetMap.Get("non-existent") + assert.Nil(t, result) + + // Test setting and getting values + testPreset := &circuitConfig{Name: "test-preset"} + circuitPresetMap.Set("test-preset", testPreset) + + result = circuitPresetMap.Get("test-preset") + assert.Equal(t, testPreset, result) + + // Test Has functionality + assert.True(t, circuitPresetMap.Has("test-preset")) + assert.False(t, circuitPresetMap.Has("non-existent")) + + // Test Remove functionality + circuitPresetMap.Remove("test-preset") + assert.False(t, circuitPresetMap.Has("test-preset")) + assert.Nil(t, circuitPresetMap.Get("test-preset")) +} diff --git a/modules/client/client.go b/modules/client/client.go index 6d67cd7..b5c5133 100644 --- a/modules/client/client.go +++ b/modules/client/client.go @@ -2,7 +2,6 @@ package client import ( "context" - "github.com/Trendyol/chaki/config" "github.com/go-resty/resty/v2" ) @@ -18,6 +17,8 @@ type Factory struct { } func NewFactory(cfg *config.Config, wrappers []DriverWrapper) *Factory { + initCircuitPresets(cfg) + initRetryPresets(cfg) return &Factory{ cfg: cfg, baseWrappers: wrappers, @@ -34,16 +35,24 @@ func (f *Factory) Get(name string, opts ...Option) *Base { opt.Apply(cOpts) } + clientCfg := f.cfg.Of("client").Of(name) + return &Base{ - driver: newDriverBuilder(f.cfg.Of("client").Of(name)). + name: name, + driver: newDriverBuilder(clientCfg). AddErrDecoder(cOpts.errDecoder). AddUpdaters(f.baseWrappers...). AddUpdaters(cOpts.driverWrappers...). + SetRetry(getRetryConfigs(clientCfg)). + SetCircuit(getCircuitConfigs(clientCfg)). build(), - name: name, } } -func (r *Base) Request(ctx context.Context) *resty.Request { - return r.driver.R().SetContext(ctx) +func (b *Base) Request(ctx context.Context) *resty.Request { + return b.driver.R().SetContext(ctx) +} + +func (b *Base) RequestWithCommand(ctx context.Context, command string) *resty.Request { + return b.driver.R().SetContext(context.WithValue(ctx, circuitCommandKey, command)) } diff --git a/modules/client/client_test.go b/modules/client/client_test.go new file mode 100644 index 0000000..b39ce90 --- /dev/null +++ b/modules/client/client_test.go @@ -0,0 +1,155 @@ +package client + +import ( + "context" + "testing" + + "github.com/go-resty/resty/v2" + "github.com/stretchr/testify/assert" +) + +// TestNewFactory verifies that the factory is created correctly with the provided configuration and wrappers +func TestNewFactory(t *testing.T) { + // Setup + cfg := clientTestConfig() + wrapper, _ := createDriverWrapper("X-Test", "test") + wrappers := []DriverWrapper{wrapper} + + // Execute + factory := NewFactory(cfg, wrappers) + + // Verify + assert.NotNil(t, factory) + assert.Equal(t, cfg, factory.cfg) + assert.Equal(t, wrappers, factory.baseWrappers) +} + +// TestFactory_Get verifies that the factory creates clients correctly with various options +func TestFactory_Get(t *testing.T) { + // Define test cases + testCases := []struct { + name string + setupOptions []Option + verifyFunction func(t *testing.T, client *Base) + }{ + { + name: "creates client with default options", + setupOptions: nil, + verifyFunction: func(t *testing.T, client *Base) { + assert.Equal(t, "testclient", client.name) + assert.NotNil(t, client.driver) + assert.Equal(t, "https://example.com", client.driver.HostURL) + }, + }, + { + name: "applies options correctly", + setupOptions: func() []Option { + customErrDecoder, _ := createTestErrDecoder() + customWrapper, _ := createDriverWrapper("X-Custom", "value") + return []Option{ + WithErrDecoder(customErrDecoder), + WithDriverWrappers(customWrapper), + } + }(), + verifyFunction: func(t *testing.T, client *Base) { + assert.NotNil(t, client) + // Additional verification can be done here if needed + }, + }, + } + + // Run test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup + cfg := clientTestConfig() + factory := NewFactory(cfg, nil) + + // Execute + client := factory.Get("testclient", tc.setupOptions...) + + // Verify + assert.NotNil(t, client) + tc.verifyFunction(t, client) + }) + } + + // Additional test for error decoder functionality + t.Run("error decoder is applied and works", func(t *testing.T) { + // Setup + cfg := clientTestConfig() + factory := NewFactory(cfg, nil) + + customErrDecoder, decoderCalled := createTestErrDecoder() + + // Execute + client := factory.Get("testclient", WithErrDecoder(customErrDecoder)) + + // Create a dummy response to test the error decoder + req := client.driver.R().SetContext(context.Background()) + dummyRes := &resty.Response{ + Request: req, + } + // Manually invoke the error handler + err := customErrDecoder(context.Background(), dummyRes) + + // Verify + assert.NoError(t, err) + assert.True(t, *decoderCalled, "Custom error decoder should be called") + }) + + // Additional test for wrapper functionality + t.Run("wrapper is applied and works", func(t *testing.T) { + // Setup + cfg := clientTestConfig() + factory := NewFactory(cfg, nil) + + customWrapper, wrapperCalled := createDriverWrapper("X-Custom", "value") + + // Execute + client := factory.Get("testclient", WithDriverWrappers(customWrapper)) + + // Verify + assert.NotNil(t, client) + assert.True(t, *wrapperCalled, "Custom wrapper should be applied") + // Verify the client has the expected header + assert.Equal(t, "value", client.driver.Header.Get("X-Custom")) + }) +} + +// TestBase_Request verifies that the Request method correctly creates a request with the provided context +func TestBase_Request(t *testing.T) { + // Setup + base := &Base{ + name: "testclient", + driver: resty.New(), + } + ctx := context.Background() + + // Execute + req := base.Request(ctx) + + // Verify + assert.NotNil(t, req) + assert.Equal(t, ctx, req.Context()) +} + +// TestBase_RequestWithCommand verifies that the RequestWithCommand method correctly creates a request +// with the provided context and command +func TestBase_RequestWithCommand(t *testing.T) { + // Setup + base := &Base{ + name: "testclient", + driver: resty.New(), + } + ctx := context.Background() + command := "test-command" + + // Execute + req := base.RequestWithCommand(ctx, command) + + // Verify + assert.NotNil(t, req) + // Check that the command was added to the context + assert.Equal(t, command, req.Context().Value(circuitCommandKey)) +} diff --git a/modules/client/driver.go b/modules/client/driver.go index bffb8a4..1ad43e1 100644 --- a/modules/client/driver.go +++ b/modules/client/driver.go @@ -1,17 +1,21 @@ package client import ( + "net/http" + "github.com/Trendyol/chaki/config" "github.com/Trendyol/chaki/logger" + "github.com/Trendyol/chaki/modules/client/common" "github.com/go-resty/resty/v2" "go.uber.org/zap" ) type driverBuilder struct { - cfg *config.Config - eh ErrDecoder - d *resty.Client - updaters []DriverWrapper + cfg *config.Config + eh ErrDecoder + d *resty.Client + updaters []DriverWrapper + rtWrappers []common.RoundTripperWrapper } func newDriverBuilder(cfg *config.Config) *driverBuilder { @@ -39,6 +43,35 @@ func (b *driverBuilder) AddUpdaters(wrappers ...DriverWrapper) *driverBuilder { return b } +func (b *driverBuilder) AddRoundTripperWrappers(wrappers ...common.RoundTripperWrapper) *driverBuilder { + b.rtWrappers = append(b.rtWrappers, wrappers...) + return b +} + +func (b *driverBuilder) SetRetry(retryConfig *retryConfig) *driverBuilder { + if retryConfig == nil { + return b + } + + b.rtWrappers = append(b.rtWrappers, func(rt http.RoundTripper) http.RoundTripper { + return newRetryRoundTripper(rt, retryConfig) + }) + + return b +} + +func (b *driverBuilder) SetCircuit(circuitConfig *circuitConfig) *driverBuilder { + if circuitConfig == nil { + return b + } + + b.rtWrappers = append(b.rtWrappers, func(rt http.RoundTripper) http.RoundTripper { + return newCircuitRoundTripper(rt, circuitConfig) + }) + + return b +} + func (b *driverBuilder) build() *resty.Client { if b.cfg.GetBool("logging") { b.useLogging() @@ -48,12 +81,23 @@ func (b *driverBuilder) build() *resty.Client { b.d = upd(b.d) } + b.d.SetTransport(b.buildRoundTripper()) + b.d.OnAfterResponse(func(c *resty.Client, r *resty.Response) error { return b.eh(r.Request.Context(), r) }) return b.d } +func (b *driverBuilder) buildRoundTripper() http.RoundTripper { + rt := b.d.GetClient().Transport + for _, wr := range b.rtWrappers { + rt = wr(rt) + } + + return rt +} + func (b *driverBuilder) useLogging() { b.d.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error { logger.From(r.Context()).Info( @@ -83,4 +127,7 @@ func setDefaults(cfg *config.Config) { cfg.SetDefault("timeout", "5s") cfg.SetDefault("debug", false) cfg.SetDefault("logging", false) + + setDefaultCircuitConfigs(cfg) + setDefaultRetryConfigs(cfg) } diff --git a/modules/client/driver_test.go b/modules/client/driver_test.go new file mode 100644 index 0000000..8ec0f5e --- /dev/null +++ b/modules/client/driver_test.go @@ -0,0 +1,382 @@ +package client + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/go-resty/resty/v2" + "github.com/stretchr/testify/assert" +) + +// TestNewDriverBuilder verifies that a new driver builder is created correctly with the provided configuration +func TestNewDriverBuilder(t *testing.T) { + // Setup + cfg := driverTestConfig() + + // Execute + builder := newDriverBuilder(cfg) + + // Verify + assert.NotNil(t, builder) + assert.Equal(t, cfg, builder.cfg) + assert.NotNil(t, builder.d) + assert.Equal(t, "https://example.com", builder.d.HostURL) + assert.Equal(t, 1*time.Second, builder.d.GetClient().Timeout) + assert.False(t, builder.d.Debug) +} + +// TestDriverBuilder_AddErrDecoder verifies that error decoders are added correctly to the driver builder +func TestDriverBuilder_AddErrDecoder(t *testing.T) { + // Setup + cfg := driverTestConfig() + builder := newDriverBuilder(cfg) + errDecoder, decoderCalled := createTestErrDecoder() + + // Execute + result := builder.AddErrDecoder(errDecoder) + + // Verify + assert.NotNil(t, result) + // Don't compare functions directly, test that it's set to something + assert.NotNil(t, builder.eh) + assert.Same(t, builder, result) // Should return itself for chaining + + // Verify the decoder works by calling it + dummyRes := &resty.Response{ + Request: resty.New().R().SetContext(context.Background()), + } + err := builder.eh(context.Background(), dummyRes) + assert.NoError(t, err) + assert.True(t, *decoderCalled, "Error decoder should be called") +} + +// TestDriverBuilder_AddUpdaters verifies that driver updaters are added correctly to the driver builder +func TestDriverBuilder_AddUpdaters(t *testing.T) { + // Setup + cfg := driverTestConfig() + builder := newDriverBuilder(cfg) + + wrapper1, _ := createDriverWrapper("X-Test-1", "value1") + wrapper2, _ := createDriverWrapper("X-Test-2", "value2") + + // Execute + result := builder.AddUpdaters(wrapper1, wrapper2) + + // Verify + assert.NotNil(t, result) + assert.Len(t, builder.updaters, 2) + assert.Same(t, builder, result) // Should return itself for chaining +} + +// TestDriverBuilder_AddRoundTripperWrappers verifies that round tripper wrappers are added correctly to the driver builder +func TestDriverBuilder_AddRoundTripperWrappers(t *testing.T) { + // Setup + cfg := driverTestConfig() + builder := newDriverBuilder(cfg) + + wrapper1 := func(rt http.RoundTripper) http.RoundTripper { return rt } + wrapper2 := func(rt http.RoundTripper) http.RoundTripper { return rt } + + // Execute + result := builder.AddRoundTripperWrappers(wrapper1, wrapper2) + + // Verify + assert.NotNil(t, result) + assert.Len(t, builder.rtWrappers, 2) + assert.Same(t, builder, result) // Should return itself for chaining +} + +// TestDriverBuilder_SetRetry verifies that retry configuration is set correctly on the driver builder +func TestDriverBuilder_SetRetry(t *testing.T) { + t.Run("with retry configuration", func(t *testing.T) { + // Setup + cfg := driverTestConfig() + builder := newDriverBuilder(cfg) + retryConfig := &retryConfig{ + Name: "test-retry", + Count: 3, + Interval: 100 * time.Millisecond, + MaxDelay: 1 * time.Second, + DelayType: ConstantDelay, + } + + // Execute + result := builder.SetRetry(retryConfig) + + // Verify + assert.NotNil(t, result) + assert.Same(t, builder, result) // Should return itself for chaining + assert.Len(t, builder.rtWrappers, 1) // Should add one round tripper wrapper + + // Build the client to verify the round tripper is properly configured + client := builder.build() + assert.NotNil(t, client) + + // Verify the transport is wrapped with our RetryRoundTripper + transport := client.GetClient().Transport + _, ok := transport.(*RetryRoundTripper) + assert.True(t, ok, "Transport should be wrapped with RetryRoundTripper") + }) + + t.Run("with nil retry configuration", func(t *testing.T) { + // Setup + cfg := driverTestConfig() + builder := newDriverBuilder(cfg) + + // Execute + result := builder.SetRetry(nil) + + // Verify + assert.NotNil(t, result) + assert.Empty(t, builder.rtWrappers) + assert.Same(t, builder, result) // Should return itself for chaining + }) +} + +func TestDriverBuilder_SetCircuit(t *testing.T) { + t.Run("with circuit configuration", func(t *testing.T) { + // Setup + cfg := driverTestConfig() + builder := newDriverBuilder(cfg) + circuitConfig := &circuitConfig{ + Enabled: true, + Name: "test-circuit", + } + + // Execute + result := builder.SetCircuit(circuitConfig) + + // Verify + assert.NotNil(t, result) + assert.Len(t, builder.rtWrappers, 1) + assert.Same(t, builder, result) // Should return itself for chaining + }) + + t.Run("with nil circuit configuration", func(t *testing.T) { + // Setup + cfg := driverTestConfig() + builder := newDriverBuilder(cfg) + + // Execute + result := builder.SetCircuit(nil) + + // Verify + assert.NotNil(t, result) + assert.Empty(t, builder.rtWrappers) + assert.Same(t, builder, result) // Should return itself for chaining + }) +} + +func TestDriverBuilder_Build(t *testing.T) { + t.Run("applies all updaters", func(t *testing.T) { + // Setup + cfg := driverTestConfig() + builder := newDriverBuilder(cfg) + + wrapper1, wrapper1Called := createDriverWrapper("X-Test-1", "value1") + wrapper2, wrapper2Called := createDriverWrapper("X-Test-2", "value2") + + builder.updaters = append(builder.updaters, wrapper1, wrapper2) + + // Add error handler + errDecoder, errorHandlerCalled := createTestErrDecoder() + builder.eh = errDecoder + + // Execute + client := builder.build() + + // Verify + assert.NotNil(t, client) + assert.True(t, *wrapper1Called, "First wrapper should be applied") + assert.True(t, *wrapper2Called, "Second wrapper should be applied") + + // Create a dummy response to test the error handler + req := client.R().SetContext(context.Background()) + dummyRes := &resty.Response{ + Request: req, + } + + // Instead of directly accessing OnAfterResponseFuncs, create a request and test + // the error handler manually + err := errDecoder(context.Background(), dummyRes) + assert.NoError(t, err) + assert.True(t, *errorHandlerCalled, "Error handler should be called") + }) + + t.Run("configures round trippers in correct order", func(t *testing.T) { + // Setup + cfg := driverTestConfig() + cfg.Set("baseurl", "https://test-api.example.com") + cfg.Set("logging", false) + builder := newDriverBuilder(cfg) + + // Add round tripper wrappers in a specific order + executionOrder := make([]int, 0) + + builder.rtWrappers = append(builder.rtWrappers, + func(rt http.RoundTripper) http.RoundTripper { + executionOrder = append(executionOrder, 1) + return rt + }, + func(rt http.RoundTripper) http.RoundTripper { + executionOrder = append(executionOrder, 2) + return rt + }, + ) + + // Execute + builder.build() + + // Verify + assert.Equal(t, []int{1, 2}, executionOrder, "Round trippers should be applied in order") + }) + + t.Run("enables logging when configured", func(t *testing.T) { + // Setup + cfg := driverTestConfig() + cfg.Set("baseurl", "https://test-api.example.com") + cfg.Set("logging", true) + builder := newDriverBuilder(cfg) + + // Execute + client := builder.build() + + // Verify + assert.NotNil(t, client) + // Don't check internal fields directly, verify functionality instead + assert.NotNil(t, client) + }) +} + +func TestSetDefaults(t *testing.T) { + // Setup - use driverTestConfig but reset the values we want to test + cfg := driverTestConfig() + + // Reset the values we want to test defaults for + cfg.Set("timeout", nil) + cfg.Set("debug", nil) + cfg.Set("logging", nil) + + // Execute + setDefaults(cfg) + + // Verify + assert.Equal(t, "5s", cfg.GetString("timeout")) + assert.False(t, cfg.GetBool("debug")) + assert.False(t, cfg.GetBool("logging")) +} + +func TestDriverBuilder_Integration(t *testing.T) { + // Setup a test server with our standardHandler utility + server := mockServer(standardHandler(http.StatusOK, map[string]interface{}{ + "message": "success", + }, 0)) + defer server.Close() + + // Create config + cfg := driverTestConfig() + cfg.Set("baseurl", server.URL) + + // Create wrappers for testing + wrapper, wrapperCalled := createDriverWrapper("X-Test", "value") + errDecoder, errDecoderCalled := createTestErrDecoder() + + // Create the builder with all the features + builder := newDriverBuilder(cfg). + AddErrDecoder(errDecoder). + AddUpdaters(wrapper). + AddRoundTripperWrappers(func(rt http.RoundTripper) http.RoundTripper { + return rt + }) + + // Build the client + client := builder.build() + + // Make a test request + resp, err := client.R().Get("/") + + // Verify + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode()) + assert.True(t, *wrapperCalled, "Updater should be called") + // The error decoder will be called with a 200 status, so no error should be returned + assert.True(t, *errDecoderCalled, "Error decoder should be called") +} + +func TestDriverBuilder_WithMockRoundTripper(t *testing.T) { + // Setup + cfg := driverTestConfig() + cfg.Set("baseurl", "https://test-api.example.com") + builder := newDriverBuilder(cfg) + + // Create a mock round tripper with custom behavior + mockRT := &MockRoundTripper{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return createSuccessResponse(req, `{"custom": "response"}`), nil + }, + } + + // Replace the default transport with our mock + builder.d.SetTransport(mockRT) + builder.eh = DefaultErrDecoder("testClient") + + // Add a round tripper wrapper that should wrap our mock + wrapperCalled := false + builder.AddRoundTripperWrappers(func(rt http.RoundTripper) http.RoundTripper { + wrapperCalled = true + // Verify that we're wrapping the MockRoundTripper + _, isMockRT := rt.(*MockRoundTripper) + assert.True(t, isMockRT, "Should be wrapping our MockRoundTripper") + return rt + }) + + // Build the client + client := builder.build() + + // Verify + assert.NotNil(t, client) + assert.True(t, wrapperCalled, "RoundTripper wrapper should be called") + + // Test HTTP request using our mock client + resp, err := client.R().Get("/test-endpoint") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode()) + assert.Contains(t, string(resp.Body()), "custom") + assert.Equal(t, 1, mockRT.RequestCount, "Should have made exactly one request") +} + +func TestDriverBuilder_UseLogging(t *testing.T) { + t.Run("registers and executes logging callbacks", func(t *testing.T) { + // Setup + cfg := driverTestConfig() + builder := newDriverBuilder(cfg) + + // Execute - call useLogging directly + builder.useLogging() + + // Create a test server that we'll use to verify the logging works + server := mockServer(standardHandler(http.StatusOK, map[string]interface{}{ + "response": "test", + }, 0)) + defer server.Close() + + // Configure the client to use our test server + builder.d.SetBaseURL(server.URL) + + // Make a test request with various parameters to test logging + req := builder.d.R(). + SetHeader("X-Test", "test-value"). + SetQueryParam("param", "value"). + SetBody(map[string]string{"key": "value"}) + + // Execute the request which should trigger both logging callbacks + resp, err := req.Execute("GET", "/test") + + // Verify + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode()) + }) +} diff --git a/modules/client/error.go b/modules/client/error.go index b468406..23aaffa 100644 --- a/modules/client/error.go +++ b/modules/client/error.go @@ -11,6 +11,16 @@ import ( type ErrDecoder func(context.Context, *resty.Response) error +func DefaultErrDecoder(name string) ErrDecoder { + return func(_ context.Context, res *resty.Response) error { + if res.IsSuccess() { + return nil + } + + return NewGenericClientError(name, res.StatusCode(), res.Body()) + } +} + type GenericClientError struct { ClientName string StatusCode int @@ -18,6 +28,23 @@ type GenericClientError struct { ParsedBody interface{} } +func NewGenericClientError(clientName string, statusCode int, rawBody []byte) GenericClientError { + apiErr := GenericClientError{ + ClientName: clientName, + StatusCode: statusCode, + RawBody: rawBody, + } + + var jsonBody interface{} + if err := json.Unmarshal(rawBody, &jsonBody); err == nil { + apiErr.ParsedBody = jsonBody + } else { + apiErr.ParsedBody = string(rawBody) + } + + return apiErr +} + func (e GenericClientError) Error() string { msg := fmt.Sprintf("Error on client %s (Status %d)", e.ClientName, e.StatusCode) if details := e.extractErrorDetails(); details != "" { @@ -26,6 +53,10 @@ func (e GenericClientError) Error() string { return msg } +func (e GenericClientError) Status() int { + return e.StatusCode +} + func (e GenericClientError) extractErrorDetails() string { var details []string @@ -53,26 +84,3 @@ func (e GenericClientError) extractErrorDetails() string { return strings.Join(details, "; ") } - -func DefaultErrDecoder(name string) ErrDecoder { - return func(_ context.Context, res *resty.Response) error { - if res.IsSuccess() { - return nil - } - - apiErr := GenericClientError{ - ClientName: name, - StatusCode: res.StatusCode(), - RawBody: res.Body(), - } - - var jsonBody interface{} - if err := json.Unmarshal(res.Body(), &jsonBody); err == nil { - apiErr.ParsedBody = jsonBody - } else { - apiErr.ParsedBody = string(res.Body()) - } - - return apiErr - } -} diff --git a/modules/client/error_filter.go b/modules/client/error_filter.go new file mode 100644 index 0000000..ab85b5c --- /dev/null +++ b/modules/client/error_filter.go @@ -0,0 +1,23 @@ +package client + +import ( + "context" +) + +// WIP +type errorFilterFunc func(error) (bool, error) + +func SetErrorFilter(ctx context.Context, filter errorFilterFunc) context.Context { + return context.WithValue(ctx, circuitErrFilterKey, filter) +} + +func getErrorFilterFunc(ctx context.Context) errorFilterFunc { + if fb, ok := ctx.Value(circuitErrFilterKey).(errorFilterFunc); ok { + return fb + } + return defaultErrorFilter +} + +func defaultErrorFilter(err error) (bool, error) { + return true, err +} diff --git a/modules/client/error_test.go b/modules/client/error_test.go new file mode 100644 index 0000000..3bf193f --- /dev/null +++ b/modules/client/error_test.go @@ -0,0 +1,61 @@ +package client + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestErrorResponse verifies that error responses are correctly created and handled +func TestErrorResponse(t *testing.T) { + // Setup + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com", nil) + + // Test cases for different error responses + testCases := []struct { + name string + statusCode int + body string + expectedStatus int + }{ + { + name: "bad request error", + statusCode: http.StatusBadRequest, + body: `{"error": "Bad Request"}`, + expectedStatus: http.StatusBadRequest, + }, + { + name: "unauthorized error", + statusCode: http.StatusUnauthorized, + body: `{"error": "Unauthorized"}`, + expectedStatus: http.StatusUnauthorized, + }, + { + name: "internal server error", + statusCode: http.StatusInternalServerError, + body: `{"error": "Internal Server Error"}`, + expectedStatus: http.StatusInternalServerError, + }, + { + name: "default error message", + statusCode: http.StatusServiceUnavailable, + body: "", // Empty body to test default error message + expectedStatus: http.StatusServiceUnavailable, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create error response using the utility function + resp := createErrorResponse(req, tc.statusCode, tc.body) + + // Verify the response + assert.NotNil(t, resp) + assert.Equal(t, tc.expectedStatus, resp.StatusCode) + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + assert.Equal(t, req, resp.Request) + }) + } +} diff --git a/modules/client/fallback.go b/modules/client/fallback.go new file mode 100644 index 0000000..9b02f71 --- /dev/null +++ b/modules/client/fallback.go @@ -0,0 +1,74 @@ +package client + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" +) + +type fallbackFunc func(context.Context, error) (any, error) + +func SetFallbackFunc(ctx context.Context, fb fallbackFunc) context.Context { + return context.WithValue(ctx, circuitFallbackKey, fb) +} + +type fallbackHandler struct { + fn func(context.Context, error) (interface{}, error) + resp *http.Response + executed bool +} + +func newOrDefaultFallbackHandler(ctx context.Context) *fallbackHandler { + h := &fallbackHandler{} + + if fn, ok := ctx.Value(circuitFallbackKey).(fallbackFunc); ok { + h.fn = fn + } else { + h.fn = defaultFallbackFn + } + + return h +} + +func defaultFallbackFn(_ context.Context, e error) (interface{}, error) { + return nil, e +} + +func (f *fallbackHandler) handle(ctx context.Context, e error) error { + resp, err := f.fn(ctx, e) + if err != nil { + return err + } + + body, contentLength, contentType, err := interfaceToReadCloserWithLength(resp) + if err != nil { + return err + } + + f.resp = &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: body, + Header: make(http.Header), + ContentLength: contentLength, + } + + if contentType != "" { + f.resp.Header.Set("Content-Type", contentType) + } else { + f.resp.Header.Set("Content-Type", "application/json") + } + + f.executed = true + return nil +} + +func interfaceToReadCloserWithLength(data interface{}) (io.ReadCloser, int64, string, error) { + b, err := json.Marshal(data) + if err != nil { + return nil, 0, "", err + } + return io.NopCloser(bytes.NewReader(b)), int64(len(b)), "application/json", nil +} diff --git a/modules/client/fallback_test.go b/modules/client/fallback_test.go new file mode 100644 index 0000000..fcf5cac --- /dev/null +++ b/modules/client/fallback_test.go @@ -0,0 +1,188 @@ +package client + +import ( + "context" + "errors" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetFallbackFunc(t *testing.T) { + // Arrange + ctx := context.Background() + var called bool + testFn := func(ctx context.Context, err error) (interface{}, error) { + called = true + return nil, nil + } + + // Act + ctxWithFallback := SetFallbackFunc(ctx, testFn) + + // Assert + assert.NotEqual(t, ctx, ctxWithFallback, "Context should be different after setting fallback") + + // Verify we can retrieve the function from context + fnFromCtx, ok := ctxWithFallback.Value(circuitFallbackKey).(fallbackFunc) + assert.True(t, ok, "Should be able to retrieve fallback function from context") + + // Call the retrieved function to verify it's our function + _, _ = fnFromCtx(ctx, nil) + assert.True(t, called, "Retrieved function should call our test function") +} + +func TestNewOrDefaultFallbackHandler(t *testing.T) { + t.Run("with custom fallback function", func(t *testing.T) { + // Arrange + testData := map[string]string{"key": "value"} + fn, called, _ := createTestFallbackFunc(testData) + ctx := SetFallbackFunc(context.Background(), fn) + + // Act + handler := newOrDefaultFallbackHandler(ctx) + + // Assert - Test function behavior rather than comparing functions directly + // Call both functions with same input and compare results + testErr := errors.New("test error") + result1, err1 := handler.fn(context.Background(), testErr) + result2, err2 := fn(context.Background(), testErr) + + assert.Equal(t, result1, result2, "Handler function should return same result as original") + assert.Equal(t, err1, err2, "Handler function should return same error as original") + assert.True(t, *called, "Function should be called") + }) + + t.Run("with default fallback function", func(t *testing.T) { + // Arrange + ctx := context.Background() + + // Act + handler := newOrDefaultFallbackHandler(ctx) + + // Execute to verify it's the default function + testError := errors.New("test error") + resp, err := handler.fn(ctx, testError) + + // Assert + assert.Nil(t, resp, "Default function should return nil response") + assert.Equal(t, testError, err, "Default function should return the original error") + }) +} + +func TestDefaultFallbackFn(t *testing.T) { + // Arrange + testError := errors.New("test error") + ctx := context.Background() + + // Act + resp, err := defaultFallbackFn(ctx, testError) + + // Assert + assert.Nil(t, resp, "Default fallback function should return nil response") + assert.Equal(t, testError, err, "Default fallback function should return original error") +} + +func TestFallbackHandler_Handle(t *testing.T) { + t.Run("successful fallback", func(t *testing.T) { + // Arrange + testData := map[string]string{"key": "value"} + fn, called, passedErr := createTestFallbackFunc(testData) + ctx := context.Background() + testError := errors.New("test error") + + handler := &fallbackHandler{fn: fn} + + // Act + err := handler.handle(ctx, testError) + + // Assert + assert.NoError(t, err, "Handle should not return error when fallback succeeds") + assert.True(t, *called, "Fallback function should be called") + assert.Equal(t, testError, *passedErr, "Original error should be passed to fallback") + assert.True(t, handler.executed, "Handler should mark execution as completed") + + // Verify response structure + require.NotNil(t, handler.resp, "Response should be created") + assert.Equal(t, http.StatusOK, handler.resp.StatusCode) + assert.Equal(t, "application/json", handler.resp.Header.Get("Content-Type")) + + // Verify response body + respBody, err := io.ReadAll(handler.resp.Body) + require.NoError(t, err) + assert.Contains(t, string(respBody), `"key":"value"`, "Response body should contain our test data") + }) + + t.Run("failing fallback", func(t *testing.T) { + // Arrange + fn, called := createFailingFallbackFunc() + ctx := context.Background() + testError := errors.New("test error") + + handler := &fallbackHandler{fn: fn} + + // Act + err := handler.handle(ctx, testError) + + // Assert + assert.Error(t, err, "Handle should return error when fallback fails") + assert.True(t, *called, "Fallback function should be called") + assert.Equal(t, testError, err, "Original error should be returned when fallback fails") + assert.False(t, handler.executed, "Handler should not mark execution as completed") + }) +} + +func TestInterfaceToReadCloserWithLength(t *testing.T) { + testCases := []struct { + name string + data interface{} + expectError bool + }{ + { + name: "map data", + data: map[string]string{"key": "value"}, + expectError: false, + }, + { + name: "slice data", + data: []string{"value1", "value2"}, + expectError: false, + }, + { + name: "struct data", + data: struct{ Name string }{"test"}, + expectError: false, + }, + { + name: "channel data (not marshallable)", + data: make(chan int), + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Act + rc, length, contentType, err := interfaceToReadCloserWithLength(tc.data) + + // Assert + if tc.expectError { + assert.Error(t, err, "Should return error for unmarshallable data") + return + } + + assert.NoError(t, err, "Should not return error for valid data") + assert.NotNil(t, rc, "Should return a non-nil ReadCloser") + assert.Greater(t, length, int64(0), "Content length should be greater than 0") + assert.Equal(t, "application/json", contentType, "Content type should be application/json") + + // Read the body to verify it's valid JSON + body, err := io.ReadAll(rc) + require.NoError(t, err) + assert.NotEmpty(t, body, "Body should not be empty") + }) + } +} diff --git a/modules/client/module.go b/modules/client/module.go index c240c6a..1cb9d84 100644 --- a/modules/client/module.go +++ b/modules/client/module.go @@ -1,12 +1,11 @@ package client import ( - "net/http" - "github.com/Trendyol/chaki/as" "github.com/Trendyol/chaki/module" "github.com/Trendyol/chaki/modules/client/common" "github.com/go-resty/resty/v2" + "net/http" ) var ( diff --git a/modules/client/retry.go b/modules/client/retry.go new file mode 100644 index 0000000..77cea60 --- /dev/null +++ b/modules/client/retry.go @@ -0,0 +1,126 @@ +package client + +import ( + "time" + + "github.com/Trendyol/chaki/config" + "github.com/Trendyol/chaki/util/store" +) + +type ( + DelayType string + + retryConfig struct { + Name string `json:"name"` + Count int `json:"count"` + Interval time.Duration `json:"interval"` + MaxDelay time.Duration `json:"maxDelay"` + DelayType DelayType `json:"delayType"` + } +) + +const ( + ConstantDelay DelayType = "constant" + ExponentialDelay DelayType = "exponential" +) + +var ( + retryPresetMap = store.NewBucket(func(k string) *retryConfig { return nil }) + + defaultRetryConfig = &retryConfig{ + Name: "default", + Count: 3, + Interval: 100 * time.Millisecond, + MaxDelay: 5 * time.Second, + DelayType: ConstantDelay, + } + exponentialRetryConfig = &retryConfig{ + Name: "exponential", + Count: 3, + Interval: 100 * time.Millisecond, + MaxDelay: 5 * time.Second, + DelayType: ExponentialDelay, + } + aggressiveRetryConfig = &retryConfig{ + Name: "aggressive", + Count: 7, + Interval: 50 * time.Millisecond, + MaxDelay: 2 * time.Second, + DelayType: ConstantDelay, + } + aggressiveExponentialRetryConfig = &retryConfig{ + Name: "aggressiveExponential", + Count: 7, + Interval: 50 * time.Millisecond, + MaxDelay: 2 * time.Second, + DelayType: ExponentialDelay, + } + relaxedRetryConfig = &retryConfig{ + Name: "relaxed", + Count: 2, + Interval: 500 * time.Millisecond, + MaxDelay: 2 * time.Second, + DelayType: ConstantDelay, + } + relaxedExponentialConfig = &retryConfig{ + Name: "relaxedExponential", + Count: 2, + Interval: 500 * time.Millisecond, + MaxDelay: 2 * time.Second, + DelayType: ExponentialDelay, + } +) + +func getRetryConfigs(cfg *config.Config) *retryConfig { + if !cfg.GetBool("retry.enabled") { + return nil + } + + preset := cfg.GetString("retry.preset") + switch preset { + case "custom": + rc, err := config.ToStruct[*retryConfig](cfg, "retry") + if err != nil { + panic(err) + } + return rc + default: + if rc := retryPresetMap.Get(preset); rc != nil { + return rc + } + + panic("unknown retry preset: " + preset) + } +} + +func setDefaultRetryConfigs(cfg *config.Config) { + cfg.SetDefault("retry.enabled", false) + cfg.SetDefault("retry.preset", "default") + cfg.SetDefault("retry.count", 3) + cfg.SetDefault("retry.interval", "100ms") + cfg.SetDefault("retry.maxDelay", "5s") + cfg.SetDefault("retry.delayType", "constant") +} + +func initRetryPresets(cfg *config.Config) { + predefinedRetryPresets := []*retryConfig{ + defaultRetryConfig, + exponentialRetryConfig, + aggressiveRetryConfig, + aggressiveExponentialRetryConfig, + relaxedRetryConfig, + relaxedExponentialConfig, + } + + for _, rc := range predefinedRetryPresets { + retryPresetMap.Set(rc.Name, rc) + } + + userPresets, err := config.ToStruct[[]*retryConfig](cfg, "client.retryPresets") + if err != nil { + panic(err) + } + for _, rc := range userPresets { + retryPresetMap.Set(rc.Name, rc) + } +} diff --git a/modules/client/retry_rt.go b/modules/client/retry_rt.go new file mode 100644 index 0000000..1227adb --- /dev/null +++ b/modules/client/retry_rt.go @@ -0,0 +1,50 @@ +package client + +import ( + "math" + "math/rand" + "net/http" + "time" +) + +type RetryRoundTripper struct { + next http.RoundTripper + cfg *retryConfig +} + +func newRetryRoundTripper(next http.RoundTripper, cfg *retryConfig) http.RoundTripper { + return &RetryRoundTripper{ + next: next, + cfg: cfg, + } +} + +func (r *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := r.next.RoundTrip(req) + if r.cfg == nil || r.cfg.Count == 0 { + return resp, err + } + + delay := r.cfg.Interval + for i := 0; i < r.cfg.Count && err != nil; i++ { + if r.cfg.DelayType == ExponentialDelay { + exponentialDelay := delay * time.Duration(math.Pow(2, float64(i))) + + randFloat := rand.New(rand.NewSource(time.Now().UnixNano())).Float64() + jitter := time.Duration(randFloat * float64(r.cfg.Interval)) + delay = exponentialDelay + jitter + if delay > r.cfg.MaxDelay { + delay = r.cfg.MaxDelay + } + } + + select { + case <-req.Context().Done(): + return nil, req.Context().Err() + case <-time.After(delay): + resp, err = r.next.RoundTrip(req) + } + } + + return resp, err +} diff --git a/modules/client/retry_rt_test.go b/modules/client/retry_rt_test.go new file mode 100644 index 0000000..29c93be --- /dev/null +++ b/modules/client/retry_rt_test.go @@ -0,0 +1,249 @@ +package client + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type MockRetryRoundTripper struct { + callCount int + responses []*http.Response + errors []error + requestDelay time.Duration +} + +func (m *MockRetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if m.requestDelay > 0 { + time.Sleep(m.requestDelay) + } + + if m.callCount < len(m.responses) { + resp := m.responses[m.callCount] + err := m.errors[m.callCount] + m.callCount++ + return resp, err + } + + return nil, errors.New("unexpected call to RoundTrip") +} + +func TestNewRetryRoundTripper(t *testing.T) { + // Arrange + mockRT := &MockRetryRoundTripper{} + cfg := &retryConfig{ + Name: "test", + Count: 3, + Interval: 100 * time.Millisecond, + MaxDelay: 1 * time.Second, + DelayType: ConstantDelay, + } + + // Act + rt := newRetryRoundTripper(mockRT, cfg) + + // Assert + assert.NotNil(t, rt) + retryRT, ok := rt.(*RetryRoundTripper) + assert.True(t, ok) + assert.Equal(t, mockRT, retryRT.next) + assert.Equal(t, cfg, retryRT.cfg) +} + +func TestRetryRoundTripper_RoundTrip(t *testing.T) { + testCases := []struct { + name string + config *retryConfig + responses []*http.Response + errors []error + expectedCalls int + expectedError bool + timeout time.Duration + }{ + { + name: "success on first try", + config: &retryConfig{ + Count: 3, + Interval: 10 * time.Millisecond, + DelayType: ConstantDelay, + }, + responses: []*http.Response{{StatusCode: 200}}, + errors: []error{nil}, + expectedCalls: 1, + expectedError: false, + }, + { + name: "success after retries with constant delay", + config: &retryConfig{ + Count: 3, + Interval: 10 * time.Millisecond, + DelayType: ConstantDelay, + }, + responses: []*http.Response{nil, nil, {StatusCode: 200}}, + errors: []error{errors.New("error1"), errors.New("error2"), nil}, + expectedCalls: 3, + expectedError: false, + }, + { + name: "success after retries with exponential delay", + config: &retryConfig{ + Count: 3, + Interval: 10 * time.Millisecond, + MaxDelay: 100 * time.Millisecond, + DelayType: ExponentialDelay, + }, + responses: []*http.Response{nil, nil, {StatusCode: 200}}, + errors: []error{errors.New("error1"), errors.New("error2"), nil}, + expectedCalls: 3, + expectedError: false, + }, + { + name: "failure after all retries", + config: &retryConfig{ + Count: 2, + Interval: 10 * time.Millisecond, + DelayType: ConstantDelay, + }, + responses: []*http.Response{nil, nil, nil}, + errors: []error{errors.New("error1"), errors.New("error2"), errors.New("error3")}, + expectedCalls: 3, + expectedError: true, + }, + { + name: "no retry when config is nil", + config: nil, + responses: []*http.Response{nil}, + errors: []error{errors.New("error")}, + expectedCalls: 1, + expectedError: true, + }, + { + name: "no retry when count is 0", + config: &retryConfig{ + Count: 0, + Interval: 10 * time.Millisecond, + DelayType: ConstantDelay, + }, + responses: []*http.Response{nil}, + errors: []error{errors.New("error")}, + expectedCalls: 1, + expectedError: true, + }, + { + name: "respect context cancellation", + config: &retryConfig{ + Count: 3, + Interval: 100 * time.Millisecond, + DelayType: ConstantDelay, + }, + responses: []*http.Response{nil}, + errors: []error{errors.New("error")}, + expectedCalls: 1, + expectedError: true, + timeout: 50 * time.Millisecond, + }, + { + name: "respect max delay with exponential backoff", + config: &retryConfig{ + Count: 3, + Interval: 10 * time.Millisecond, + MaxDelay: 15 * time.Millisecond, + DelayType: ExponentialDelay, + }, + responses: []*http.Response{nil, nil, {StatusCode: 200}}, + errors: []error{errors.New("error1"), errors.New("error2"), nil}, + expectedCalls: 3, + expectedError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Arrange + mockRT := &MockRetryRoundTripper{ + responses: tc.responses, + errors: tc.errors, + } + + rt := &RetryRoundTripper{ + next: mockRT, + cfg: tc.config, + } + + ctx := context.Background() + if tc.timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, tc.timeout) + defer cancel() + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil) + require.NoError(t, err) + + // Act + resp, err := rt.RoundTrip(req) + + // Assert + assert.Equal(t, tc.expectedCalls, mockRT.callCount) + if tc.expectedError { + assert.Error(t, err) + if tc.timeout > 0 { + assert.ErrorIs(t, err, context.DeadlineExceeded) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 200, resp.StatusCode) + } + }) + } +} + +func TestRetryRoundTripper_ExponentialBackoff(t *testing.T) { + // Arrange + cfg := &retryConfig{ + Count: 3, + Interval: 10 * time.Millisecond, + MaxDelay: 50 * time.Millisecond, + DelayType: ExponentialDelay, + } + + mockRT := &MockRetryRoundTripper{ + responses: []*http.Response{nil, nil, {StatusCode: 200}}, + errors: []error{errors.New("error1"), errors.New("error2"), nil}, + requestDelay: 1 * time.Millisecond, // Small delay to make timing more realistic + } + + rt := &RetryRoundTripper{ + next: mockRT, + cfg: cfg, + } + + req, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + require.NoError(t, err) + + // Act + start := time.Now() + resp, err := rt.RoundTrip(req) + duration := time.Since(start) + + // Assert + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 3, mockRT.callCount) + + // The actual duration should be greater than the sum of the minimum delays + // First retry: 10ms + // Second retry: 20ms + jitter + minExpectedDuration := 30 * time.Millisecond + assert.Greater(t, duration, minExpectedDuration) + + // But should not exceed maximum delay plus some buffer for execution time + maxExpectedDuration := cfg.MaxDelay + (100 * time.Millisecond) + assert.Less(t, duration, maxExpectedDuration) +} diff --git a/modules/client/retry_test.go b/modules/client/retry_test.go new file mode 100644 index 0000000..801a746 --- /dev/null +++ b/modules/client/retry_test.go @@ -0,0 +1,267 @@ +package client + +import ( + "testing" + "time" + + "github.com/Trendyol/chaki/config" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestDelayType verifies that delay type string representations are correct +func TestDelayType(t *testing.T) { + tests := []struct { + name string + delayType DelayType + expectedString string + }{ + { + name: "constant delay type", + delayType: ConstantDelay, + expectedString: "constant", + }, + { + name: "exponential delay type", + delayType: ExponentialDelay, + expectedString: "exponential", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expectedString, string(tt.delayType)) + }) + } +} + +// TestRetryPresets verifies that predefined retry configurations have the expected values +func TestRetryPresets(t *testing.T) { + tests := []struct { + name string + preset *retryConfig + validate func(*testing.T, *retryConfig) + }{ + { + name: "default preset", + preset: defaultRetryConfig, + validate: func(t *testing.T, cfg *retryConfig) { + assert.Equal(t, "default", cfg.Name) + assert.Equal(t, 3, cfg.Count) + assert.Equal(t, 100*time.Millisecond, cfg.Interval) + assert.Equal(t, 5*time.Second, cfg.MaxDelay) + assert.Equal(t, ConstantDelay, cfg.DelayType) + }, + }, + { + name: "exponential preset", + preset: exponentialRetryConfig, + validate: func(t *testing.T, cfg *retryConfig) { + assert.Equal(t, "exponential", cfg.Name) + assert.Equal(t, 3, cfg.Count) + assert.Equal(t, 100*time.Millisecond, cfg.Interval) + assert.Equal(t, 5*time.Second, cfg.MaxDelay) + assert.Equal(t, ExponentialDelay, cfg.DelayType) + }, + }, + { + name: "aggressive preset", + preset: aggressiveRetryConfig, + validate: func(t *testing.T, cfg *retryConfig) { + assert.Equal(t, "aggressive", cfg.Name) + assert.Equal(t, 7, cfg.Count) + assert.Equal(t, 50*time.Millisecond, cfg.Interval) + assert.Equal(t, 2*time.Second, cfg.MaxDelay) + assert.Equal(t, ConstantDelay, cfg.DelayType) + }, + }, + { + name: "aggressive exponential preset", + preset: aggressiveExponentialRetryConfig, + validate: func(t *testing.T, cfg *retryConfig) { + assert.Equal(t, "aggressiveExponential", cfg.Name) + assert.Equal(t, 7, cfg.Count) + assert.Equal(t, 50*time.Millisecond, cfg.Interval) + assert.Equal(t, 2*time.Second, cfg.MaxDelay) + assert.Equal(t, ExponentialDelay, cfg.DelayType) + }, + }, + { + name: "relaxed preset", + preset: relaxedRetryConfig, + validate: func(t *testing.T, cfg *retryConfig) { + assert.Equal(t, "relaxed", cfg.Name) + assert.Equal(t, 2, cfg.Count) + assert.Equal(t, 500*time.Millisecond, cfg.Interval) + assert.Equal(t, 2*time.Second, cfg.MaxDelay) + assert.Equal(t, ConstantDelay, cfg.DelayType) + }, + }, + { + name: "relaxed exponential preset", + preset: relaxedExponentialConfig, + validate: func(t *testing.T, cfg *retryConfig) { + assert.Equal(t, "relaxedExponential", cfg.Name) + assert.Equal(t, 2, cfg.Count) + assert.Equal(t, 500*time.Millisecond, cfg.Interval) + assert.Equal(t, 2*time.Second, cfg.MaxDelay) + assert.Equal(t, ExponentialDelay, cfg.DelayType) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Execute validation function + tt.validate(t, tt.preset) + }) + } +} + +func TestGetRetryConfigs(t *testing.T) { + t.Run("disabled retry", func(t *testing.T) { + v := viper.New() + cfg := config.NewConfig(v, nil) + cfg.Set("retry.enabled", false) + + result := getRetryConfigs(cfg) + assert.Nil(t, result) + }) + + t.Run("custom configuration", func(t *testing.T) { + v := viper.New() + cfg := config.NewConfig(v, nil) + cfg.Set("retry.enabled", true) + cfg.Set("retry.preset", "custom") + cfg.Set("retry.name", "custom-retry") + cfg.Set("retry.count", 5) + cfg.Set("retry.interval", "200ms") + cfg.Set("retry.maxDelay", "3s") + cfg.Set("retry.delayType", "exponential") + + result := getRetryConfigs(cfg) + require.NotNil(t, result) + assert.Equal(t, "custom-retry", result.Name) + assert.Equal(t, 5, result.Count) + assert.Equal(t, 200*time.Millisecond, result.Interval) + assert.Equal(t, 3*time.Second, result.MaxDelay) + assert.Equal(t, ExponentialDelay, result.DelayType) + }) + + t.Run("preset configuration", func(t *testing.T) { + v := viper.New() + cfg := config.NewConfig(v, nil) + cfg.Set("retry.enabled", true) + cfg.Set("retry.preset", "default") + + result := getRetryConfigs(cfg) + require.NotNil(t, result) + assert.Equal(t, defaultRetryConfig, result) + }) + + t.Run("unknown preset", func(t *testing.T) { + v := viper.New() + cfg := config.NewConfig(v, nil) + cfg.Set("retry.enabled", true) + cfg.Set("retry.preset", "unknown") + + assert.Panics(t, func() { + getRetryConfigs(cfg) + }) + }) +} + +func TestSetDefaultRetryConfigs(t *testing.T) { + // Arrange + v := viper.New() + cfg := config.NewConfig(v, nil) + + // Act + setDefaultRetryConfigs(cfg) + + // Assert + assert.False(t, cfg.GetBool("retry.enabled")) + assert.Equal(t, "default", cfg.GetString("retry.preset")) + assert.Equal(t, 3, cfg.GetInt("retry.count")) + assert.Equal(t, "100ms", cfg.GetString("retry.interval")) + assert.Equal(t, "5s", cfg.GetString("retry.maxDelay")) + assert.Equal(t, "constant", cfg.GetString("retry.delayType")) +} + +func TestInitRetryPresets(t *testing.T) { + t.Run("predefined presets", func(t *testing.T) { + // Arrange + v := viper.New() + cfg := config.NewConfig(v, nil) + + // Act + initRetryPresets(cfg) + + // Assert - Check if all predefined presets are registered + presets := []string{"default", "exponential", "aggressive", "aggressiveExponential", "relaxed", "relaxedExponential"} + for _, preset := range presets { + rc := retryPresetMap.Get(preset) + assert.NotNil(t, rc, "Preset %s should be registered", preset) + } + }) + + t.Run("custom presets", func(t *testing.T) { + // Arrange + v := viper.New() + cfg := config.NewConfig(v, nil) + customPreset := &retryConfig{ + Name: "custom", + Count: 5, + Interval: 200 * time.Millisecond, + MaxDelay: 3 * time.Second, + DelayType: ExponentialDelay, + } + cfg.Set("client.retryPresets", []*retryConfig{customPreset}) + + // Act + initRetryPresets(cfg) + + // Assert + rc := retryPresetMap.Get("custom") + require.NotNil(t, rc) + assert.Equal(t, customPreset, rc) + }) + + t.Run("invalid custom presets", func(t *testing.T) { + // Arrange + v := viper.New() + cfg := config.NewConfig(v, nil) + cfg.Set("client.retryPresets", "invalid") // Invalid type + + // Act & Assert + assert.Panics(t, func() { + initRetryPresets(cfg) + }) + }) +} + +func TestDelayType_Values(t *testing.T) { + tests := []struct { + name string + delayType DelayType + expected string + }{ + { + name: "constant delay", + delayType: ConstantDelay, + expected: "constant", + }, + { + name: "exponential delay", + delayType: ExponentialDelay, + expected: "exponential", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, string(tc.delayType)) + }) + } +} diff --git a/modules/client/test_utils_test.go b/modules/client/test_utils_test.go new file mode 100644 index 0000000..8340199 --- /dev/null +++ b/modules/client/test_utils_test.go @@ -0,0 +1,238 @@ +// Package client provides HTTP client functionality with circuit breaker, retry, and fallback capabilities. +package client + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "time" + + "github.com/spf13/viper" + + "github.com/Trendyol/chaki/config" + "github.com/afex/hystrix-go/hystrix" + "github.com/go-resty/resty/v2" +) + +// --- Configuration Utilities --- + +// clientTestConfig creates a config with proper structure for client_test.go +// This sets up the config in a nested structure as expected by Factory.Get() +func clientTestConfig() *config.Config { + cfg := config.NewConfig(viper.New(), nil) + // Set up values as needed by the client package tests + // The client package expects a nested structure "client.{clientName}.baseurl" + cfg.Set("client.testclient.baseurl", "https://example.com") + cfg.Set("client.testclient.timeout", "1s") + cfg.Set("client.testclient.debug", false) + cfg.Set("client.testclient.logging", false) + + // Circuit breaker defaults + cfg.Set("client.testclient.circuit.enabled", false) + cfg.Set("client.testclient.circuit.preset", "default") + + // Retry defaults + cfg.Set("client.testclient.retry.enabled", false) + cfg.Set("client.testclient.retry.count", 3) + cfg.Set("client.testclient.retry.waitTime", "100ms") + cfg.Set("client.testclient.retry.maxWaitTime", "1s") + + return cfg +} + +// driverTestConfig creates a config with flat structure for driver_test.go +// This sets up the config in a flat structure as expected by driverBuilder +func driverTestConfig() *config.Config { + cfg := config.NewConfig(viper.New(), nil) + // Set up values as needed by the driver builder + // The driver builder expects values at the root level + cfg.Set("baseurl", "https://example.com") + cfg.Set("timeout", "1s") + cfg.Set("debug", false) + cfg.Set("logging", false) + + // Circuit breaker defaults + cfg.Set("circuit.enabled", false) + cfg.Set("circuit.preset", "default") + + // Retry defaults + cfg.Set("retry.enabled", false) + cfg.Set("retry.count", 3) + cfg.Set("retry.waitTime", "100ms") + cfg.Set("retry.maxWaitTime", "1s") + + return cfg +} + +// --- HTTP Server Mocking --- + +// mockServer creates a test HTTP server for integration testing +func mockServer(handler http.HandlerFunc) *httptest.Server { + return httptest.NewServer(handler) +} + +// standardHandler returns a common test handler that returns configurable responses +func standardHandler(statusCode int, body map[string]interface{}, responseDelay time.Duration) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if responseDelay > 0 { + time.Sleep(responseDelay) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + if body != nil { + json.NewEncoder(w).Encode(body) + } + } +} + +// --- HTTP Response Utilities --- + +// createSuccessResponse generates a successful HTTP response +func createSuccessResponse(req *http.Request, body string) *http.Response { + if body == "" { + body = `{"success": true}` + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + Request: req, + } +} + +// createErrorResponse generates an error HTTP response +// Used in retry_rt_test.go and circuit_rt_test.go +// +//nolint:unused +func createErrorResponse(req *http.Request, statusCode int, body string) *http.Response { + if body == "" { + body = `{"error": "Something went wrong"}` + } + + return &http.Response{ + StatusCode: statusCode, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + Request: req, + } +} + +// --- Mock RoundTripper --- + +// MockRoundTripper allows tracking and simulating HTTP requests +type MockRoundTripper struct { + RoundTripFunc func(req *http.Request) (*http.Response, error) + RequestCount int + LastRequest *http.Request + RecordedRequests []*http.Request +} + +// RoundTrip implements the http.RoundTripper interface +func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + m.RequestCount++ + m.LastRequest = req + m.RecordedRequests = append(m.RecordedRequests, req.Clone(req.Context())) + + if m.RoundTripFunc != nil { + return m.RoundTripFunc(req) + } + + // Default success response + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"success": true}`)), + Header: http.Header{}, + }, nil +} + +// --- Error Handling Utilities --- + +// createTestErrDecoder creates an ErrDecoder that tracks its calls +func createTestErrDecoder() (ErrDecoder, *bool) { + called := false + decoder := func(ctx context.Context, resp *resty.Response) error { + called = true + if resp.StatusCode() >= 400 { + return NewGenericClientError("test", resp.StatusCode(), resp.Body()) + } + return nil + } + return decoder, &called +} + +// --- Driver Wrapper Utilities --- + +// createDriverWrapper creates a DriverWrapper that tracks its calls +func createDriverWrapper(header, value string) (DriverWrapper, *bool) { + called := false + wrapper := func(client *resty.Client) *resty.Client { + called = true + return client.SetHeader(header, value) + } + return wrapper, &called +} + +// --- Circuit Breaker Utilities --- + +// resetHystrix clears all Hystrix command configurations and metrics +// Useful for test isolation between different test cases +func resetHystrix() { + hystrix.Flush() +} + +// withCircuitCommand adds circuit command to context for tests +func withCircuitCommand(ctx context.Context, command string) context.Context { + return context.WithValue(ctx, circuitCommandKey, command) +} + +// triggerCircuitOpen forces a circuit to open for testing circuit breaker recovery +func triggerCircuitOpen(command string) { + // Configure the command with a low error threshold + hystrix.ConfigureCommand(command, hystrix.CommandConfig{ + Timeout: 10, + MaxConcurrentRequests: 1, + ErrorPercentThreshold: 1, // Will open with just one error + RequestVolumeThreshold: 1, // Only need one request to trigger + SleepWindow: 100, + }) + + // Run a function that will time out to force the circuit to open + _ = hystrix.Do(command, func() error { + time.Sleep(20 * time.Millisecond) // Longer than the timeout + return nil + }, nil) +} + +// --- Fallback Utilities --- + +// createTestFallbackFunc creates a fallback function for testing with tracked calls +func createTestFallbackFunc(responseData interface{}) (fallbackFunc, *bool, *error) { + called := false + var passedError error + + fn := func(ctx context.Context, err error) (interface{}, error) { + called = true + passedError = err + return responseData, nil + } + + return fn, &called, &passedError +} + +// createFailingFallbackFunc creates a fallback function that returns an error +func createFailingFallbackFunc() (fallbackFunc, *bool) { + called := false + + fn := func(ctx context.Context, err error) (interface{}, error) { + called = true + return nil, err // Simply returns the passed error + } + + return fn, &called +} diff --git a/modules/common/rand/rand.go b/modules/common/rand/rand.go new file mode 100644 index 0000000..cb93309 --- /dev/null +++ b/modules/common/rand/rand.go @@ -0,0 +1,44 @@ +package rand + +import ( + "math/rand" + "sync" + "time" +) + +var ( + instance *RandomGenerator + once sync.Once +) + +func GetInstance() *RandomGenerator { + once.Do(func() { + instance = &RandomGenerator{ + source: rand.New(rand.NewSource(time.Now().UnixNano())), + } + }) + return instance +} + +type RandomGenerator struct { + source *rand.Rand + mu sync.Mutex +} + +func (g *RandomGenerator) Int() int { + g.mu.Lock() + defer g.mu.Unlock() + return g.source.Int() +} + +func (g *RandomGenerator) Float64() float64 { + g.mu.Lock() + defer g.mu.Unlock() + return g.source.Float64() +} + +func (g *RandomGenerator) Intn(n int) int { + g.mu.Lock() + defer g.mu.Unlock() + return g.source.Intn(n) +} diff --git a/modules/newrelic/client/client.go b/modules/newrelic/client/client.go index 1b4a57e..0acc187 100644 --- a/modules/newrelic/client/client.go +++ b/modules/newrelic/client/client.go @@ -1,6 +1,7 @@ package client import ( + "github.com/Trendyol/chaki/modules/client/common" "net/http" "github.com/Trendyol/chaki/module" @@ -12,8 +13,10 @@ type httpRoundTripper struct { tr http.RoundTripper } -func newRoundTripper(tr http.RoundTripper) http.RoundTripper { - return &httpRoundTripper{tr} +func newRoundTripper() common.RoundTripperWrapper { + return func(tr http.RoundTripper) http.RoundTripper { + return &httpRoundTripper{tr} + } } func (t *httpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { diff --git a/modules/server/header.go b/modules/server/header.go new file mode 100644 index 0000000..c4dcfaf --- /dev/null +++ b/modules/server/header.go @@ -0,0 +1,10 @@ +package server + +import "context" + +type ctxKey string + +const HeadersKey ctxKey = "reqHeaders" + +func GetHeaders(ctx context.Context) { +} diff --git a/modules/server/middlewares/errorhandler.go b/modules/server/middlewares/errorhandler.go index 0010978..8f65d83 100644 --- a/modules/server/middlewares/errorhandler.go +++ b/modules/server/middlewares/errorhandler.go @@ -35,5 +35,9 @@ func getCodeFromErr(err error) int { return fErr.Code } + if sErr := new(interface{ Status() int }); errors.As(err, sErr) { + return (*sErr).Status() + } + return fiber.StatusInternalServerError } diff --git a/modules/server/response/response.go b/modules/server/response/response.go index d8371ec..0ed5d38 100644 --- a/modules/server/response/response.go +++ b/modules/server/response/response.go @@ -23,8 +23,8 @@ type Response[T any] struct { ValidationErrors []validation.FieldError `json:"validationErrors,omitempty"` } -func Success(data any) Response[any] { - return Response[any]{ +func Success[T any](data T) Response[T] { + return Response[T]{ Success: true, Data: data, } diff --git a/util/store/bucket.go b/util/store/bucket.go index 4054281..4327bda 100644 --- a/util/store/bucket.go +++ b/util/store/bucket.go @@ -1,24 +1,24 @@ package store -import "sync" +import ( + csmap "github.com/mhmtszr/concurrent-swiss-map" +) type Bucket[K comparable, T any] struct { - rw sync.RWMutex - m map[K]T + m *csmap.CsMap[K, T] onDefault func(K) T } func NewBucket[K comparable, T any](onDefault func(K) T) *Bucket[K, T] { + m := csmap.Create[K, T]() return &Bucket[K, T]{ - m: make(map[K]T), + m: m, onDefault: onDefault, } } func (b *Bucket[K, T]) Get(key K) T { - b.rw.RLock() - defer b.rw.RUnlock() - v, ok := b.m[key] + v, ok := b.m.Load(key) if !ok { return b.onDefault(key) } @@ -26,13 +26,13 @@ func (b *Bucket[K, T]) Get(key K) T { } func (b *Bucket[K, T]) Set(key K, t T) { - b.rw.Lock() - defer b.rw.Unlock() - b.m[key] = t + b.m.Store(key, t) } func (b *Bucket[K, T]) Remove(key K) { - b.rw.Lock() - defer b.rw.Unlock() - delete(b.m, key) + b.m.Delete(key) +} + +func (b *Bucket[K, T]) Has(key K) bool { + return b.m.Has(key) }