Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type Config struct {
type ServerConfig struct {
Address string `yaml:"address"`
ShutdownGracePeriod *time.Duration `yaml:"shutdownGracePeriod"`
WithErrorHeader bool `yaml:"withErrorHeader"`
HTTPSConfig ServerHTTPSConfig `yaml:"https"`
}

Expand Down
11 changes: 11 additions & 0 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,17 @@ func TestSetDefaults(t *testing.T) {
t.Errorf("expected ShutdownGracePeriod 5s, got %v", *cfg.Server.ShutdownGracePeriod)
}
})

t.Run("WithErrorHeader default value", func(t *testing.T) {
t.Parallel()

configPath := setupTestFile(t, "https-disabled.yaml")
cfg := mustLoadConfig(t, configPath)

if cfg.Server.WithErrorHeader {
t.Error("expected WithErrorHeader to be false by default")
}
})
}

func mustLoadConfig(t *testing.T, configPath string) *config.Config {
Expand Down
108 changes: 61 additions & 47 deletions pkg/httpx/httpx_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,21 @@ import (
"errors"
"log"
"net/http"
"runtime/debug"
)

// WebError represents an HTTP error with an associated HTTP status code and an optional public message.
type WebError struct {
err error
httpStatus int
message string
withStackTrace bool
err error
httpStatus int
message string
}

// NewWebError creates a new WebError.
func NewWebError(err error, status int, message string) *WebError {
return &WebError{
err: err,
httpStatus: status,
message: message,
withStackTrace: true,
}
}

// NewWebErrorNoStack creates a new WebError without stack trace in logs.
func NewWebErrorNoStack(err error, status int, message string) *WebError {
return &WebError{
err: err,
httpStatus: status,
message: message,
withStackTrace: false,
err: err,
httpStatus: status,
message: message,
}
}

Expand Down Expand Up @@ -59,31 +46,48 @@ func (e *WebError) HTTPStatus() int { return e.httpStatus }
// Message returns the optional public message.
func (e *WebError) Message() string { return e.message }

// StackTrace returns whether the stack trace should be logged.
func (e *WebError) StackTrace() bool { return e.withStackTrace }

type statusCoder interface {
error
HTTPStatus() int
Message() string
}

type messageCarrier interface {
Message() string
type silentError interface {
error
Silent() bool
}

type SilentError struct {
err error
}

func NewSilentError(err error) *SilentError {
return &SilentError{err: err}
}

func (e *SilentError) Error() string {
if e.err != nil {
return e.err.Error()
}

return "silent error"
}

func (e *SilentError) Unwrap() error { return e.err }

func (e *SilentError) Silent() bool { return true }

// Compile-time check.
var (
_ statusCoder = (*WebError)(nil)
_ messageCarrier = (*WebError)(nil)
_ statusCoder = (*WebError)(nil)
_ silentError = (*SilentError)(nil)
)

type stackTracer interface {
StackTrace() bool
}

// ErrorSink returns a terminal handler that logs errors and writes appropriate HTTP responses.
// If logger is nil, log.Default() is used.
func ErrorSink(logger *log.Logger) func(WebHandler) http.Handler {
// If withErrorHeader is true, the error message is added to the X-Error-Message HTTP header.
// If the error implements silentError and Silent() returns true, it is completely ignored.
func ErrorSink(logger *log.Logger, withErrorHeader bool) func(WebHandler) http.Handler {
if logger == nil {
logger = log.Default()
}
Expand All @@ -95,28 +99,24 @@ func ErrorSink(logger *log.Logger) func(WebHandler) http.Handler {
return
}

status := http.StatusInternalServerError
msg := ""
withStackTrace := true
var se silentError
if errors.As(err, &se) && se.Silent() {
return
}

var sc statusCoder
if errors.As(err, &sc) {
status = sc.HTTPStatus()
status, msg := extractErrorInfo(err)

if mc, ok := sc.(messageCarrier); ok {
msg = mc.Message()
if withErrorHeader {
headerMsg := msg
if headerMsg == "" {
headerMsg = err.Error()
}
}

var st stackTracer
if errors.As(err, &st) {
withStackTrace = st.StackTrace()
responseWriter.Header().Set("X-Error-Message", headerMsg)
}

if status >= http.StatusInternalServerError && withStackTrace {
logger.Printf("[ERROR] %s %s: %v\nStack Trace:\n%s",
request.Method, request.URL.Path, err, debug.Stack(),
)
if status >= http.StatusInternalServerError {
logger.Printf("[ERROR] %s %s: %v", request.Method, request.URL.Path, err)
} else {
logger.Printf("[WARN] %s %s: %v", request.Method, request.URL.Path, err)
}
Expand All @@ -131,3 +131,17 @@ func ErrorSink(logger *log.Logger) func(WebHandler) http.Handler {
})
}
}

func extractErrorInfo(err error) (int, string) {
status := http.StatusInternalServerError
msg := ""

var sc statusCoder

if errors.As(err, &sc) {
status = sc.HTTPStatus()
msg = sc.Message()
}

return status, msg
}
118 changes: 118 additions & 0 deletions pkg/httpx/httpx_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package httpx_test

import (
"errors"
"net/http"
"net/http/httptest"
"testing"

"github.com/dkarczmarski/webcmd/pkg/httpx"
)

type customStatusError struct {
err error
status int
msg string
}

func (e *customStatusError) Error() string { return e.err.Error() }
func (e *customStatusError) HTTPStatus() int { return e.status }
func (e *customStatusError) Message() string { return e.msg }

func TestErrorSink(t *testing.T) {
t.Parallel()

tests := []struct {
name string
err error
withErrorHeader bool
expectedHeader string
expectedStatus int
expectedBody string
}{
{
name: "No error header, normal error (NOT statusCoder)",
err: errors.New("standard error"),
withErrorHeader: false,
expectedHeader: "",
expectedStatus: http.StatusInternalServerError,
expectedBody: "",
},
{
name: "With error header, normal error (NOT statusCoder)",
err: errors.New("standard error"),
withErrorHeader: true,
expectedHeader: "standard error",
expectedStatus: http.StatusInternalServerError,
expectedBody: "",
},
{
name: "With error header, WebError with message (statusCoder)",
err: httpx.NewWebError(errors.New("internal"), http.StatusBadRequest, "public message"),
withErrorHeader: true,
expectedHeader: "public message",
expectedStatus: http.StatusBadRequest,
expectedBody: "public message\n",
},
{
name: "With error header, WebError without message (statusCoder)",
err: httpx.NewWebError(errors.New("internal error message"), http.StatusBadRequest, ""),
withErrorHeader: true,
expectedHeader: "internal error message",
expectedStatus: http.StatusBadRequest,
expectedBody: "",
},
{
name: "Custom statusCoder implementation",
err: &customStatusError{
err: errors.New("custom error"),
status: http.StatusTeapot,
msg: "I am a teapot",
},
withErrorHeader: false,
expectedHeader: "",
expectedStatus: http.StatusTeapot,
expectedBody: "I am a teapot\n",
},
{
name: "SilentError should be ignored",
err: httpx.NewSilentError(errors.New("silent error")),
withErrorHeader: true,
expectedHeader: "",
expectedStatus: http.StatusOK,
expectedBody: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

handler := httpx.WebHandlerFunc(func(_ http.ResponseWriter, _ *http.Request) error {
return tt.err
})

sink := httpx.ErrorSink(nil, tt.withErrorHeader)
h := sink(handler)

req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
w := httptest.NewRecorder()

h.ServeHTTP(w, req)

if w.Code != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
}

gotHeader := w.Header().Get("X-Error-Message")
if gotHeader != tt.expectedHeader {
t.Errorf("expected header X-Error-Message %q, got %q", tt.expectedHeader, gotHeader)
}

gotBody := w.Body.String()
if gotBody != tt.expectedBody {
t.Errorf("expected body %q, got %q", tt.expectedBody, gotBody)
}
})
}
}
10 changes: 5 additions & 5 deletions pkg/router/handlers/execution_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func runCommand(

gate, gateErr := registry.GetOrCreate(groupName, cmd.CallGate.Mode)
if gateErr != nil {
return httpx.NewWebErrorNoStack(
return httpx.NewWebError(
gateErr, http.StatusInternalServerError, fmt.Sprintf("callgate registry: %v", gateErr),
)
}
Expand Down Expand Up @@ -152,7 +152,7 @@ func extractParams(request *http.Request, cmd *config.URLCommand) (map[string]in

bodyBytes, err := io.ReadAll(request.Body)
if err != nil {
return nil, httpx.NewWebErrorNoStack(
return nil, httpx.NewWebError(
fmt.Errorf("failed to read request body: %w", err),
http.StatusInternalServerError,
"",
Expand All @@ -176,7 +176,7 @@ func buildCommand(
) (*cmdbuilder.Result, error) {
cmdResult, err := cmdbuilder.BuildCommand(template, params)
if err != nil {
return nil, httpx.NewWebErrorNoStack(
return nil, httpx.NewWebError(
fmt.Errorf("error building command: %w", err),
http.StatusInternalServerError,
"",
Expand Down Expand Up @@ -232,7 +232,7 @@ func prepareOutput(responseWriter http.ResponseWriter, outputType string) (io.Wr
async = true
case "stream":
if _, ok := responseWriter.(http.Flusher); !ok {
return nil, false, httpx.NewWebErrorNoStack(
return nil, false, httpx.NewWebError(
fmt.Errorf("streaming not supported: %w", ErrBadConfiguration),
http.StatusInternalServerError,
"",
Expand All @@ -250,7 +250,7 @@ func prepareOutput(responseWriter http.ResponseWriter, outputType string) (io.Wr

responseWriter.Header().Set("Content-Type", "text/plain; charset=utf-8")
default:
return nil, false, httpx.NewWebErrorNoStack(
return nil, false, httpx.NewWebError(
fmt.Errorf("%w: unknown output type %q", ErrBadConfiguration, outputType),
http.StatusInternalServerError,
"",
Expand Down
Loading