From c6efaeff17d421a6fc418cf848a7f986c437fea6 Mon Sep 17 00:00:00 2001 From: David Mair Spiess Date: Wed, 19 Feb 2025 11:46:43 +0100 Subject: [PATCH 1/3] feat: add recoverer middleware --- middlewarex/recoverer.go | 40 ++++++++++++++++++++++ middlewarex/recoverer_test.go | 62 +++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 middlewarex/recoverer.go create mode 100644 middlewarex/recoverer_test.go diff --git a/middlewarex/recoverer.go b/middlewarex/recoverer.go new file mode 100644 index 0000000..e43aac1 --- /dev/null +++ b/middlewarex/recoverer.go @@ -0,0 +1,40 @@ +package middlewarex + +import ( + "errors" + "net/http" +) + +type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) + +func Recoverer(errorHandler ErrorHandler) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rvr := recover(); rvr != nil { + if rvr == http.ErrAbortHandler { + // Don’t recover http.ErrAbortHandler so the response to + // the client is aborted. + panic(rvr) + } + + var err error + switch x := rvr.(type) { + case string: + err = errors.New(x) + case error: + err = x + default: + err = errors.New("unknown panic") + } + + errorHandler(w, r, err) + } + }() + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } +} diff --git a/middlewarex/recoverer_test.go b/middlewarex/recoverer_test.go new file mode 100644 index 0000000..2b7d69b --- /dev/null +++ b/middlewarex/recoverer_test.go @@ -0,0 +1,62 @@ +package middlewarex + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRecoverer_WithPlainTextError(t *testing.T) { + plainTextErrorHandler := func(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal server error: " + err.Error())) + } + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + }) + + h := Recoverer(plainTextErrorHandler)(next) + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + + res := w.Result() + assert.Equal(t, res.StatusCode, http.StatusInternalServerError) + assert.Equal(t, w.Body.String(), "internal server error: test panic") +} + +func TestRecoverer_WithJSONError(t *testing.T) { + type ErrorResponse struct { + Status int `json:"status"` + Message string `json:"message"` + } + + jsonErrorHandler := func(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(ErrorResponse{ + Status: http.StatusInternalServerError, + Message: err.Error(), + }) + } + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + }) + + h := Recoverer(jsonErrorHandler)(next) + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + + res := w.Result() + var responseBody ErrorResponse + _ = json.NewDecoder(res.Body).Decode(&responseBody) + + assert.Equal(t, res.StatusCode, http.StatusInternalServerError) + assert.Equal(t, http.StatusInternalServerError, responseBody.Status) + assert.Equal(t, "test panic", responseBody.Message) +} From 59c08fa3ed9470758b2c5905b27dea96c1571b69 Mon Sep 17 00:00:00 2001 From: David Mair Spiess Date: Thu, 20 Feb 2025 10:28:59 +0100 Subject: [PATCH 2/3] Remove JSON error test --- middlewarex/recoverer_test.go | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/middlewarex/recoverer_test.go b/middlewarex/recoverer_test.go index 2b7d69b..1113253 100644 --- a/middlewarex/recoverer_test.go +++ b/middlewarex/recoverer_test.go @@ -1,7 +1,6 @@ package middlewarex import ( - "encoding/json" "net/http" "net/http/httptest" "testing" @@ -28,35 +27,3 @@ func TestRecoverer_WithPlainTextError(t *testing.T) { assert.Equal(t, res.StatusCode, http.StatusInternalServerError) assert.Equal(t, w.Body.String(), "internal server error: test panic") } - -func TestRecoverer_WithJSONError(t *testing.T) { - type ErrorResponse struct { - Status int `json:"status"` - Message string `json:"message"` - } - - jsonErrorHandler := func(w http.ResponseWriter, r *http.Request, err error) { - w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(ErrorResponse{ - Status: http.StatusInternalServerError, - Message: err.Error(), - }) - } - - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - panic("test panic") - }) - - h := Recoverer(jsonErrorHandler)(next) - r := httptest.NewRequest(http.MethodGet, "/", nil) - w := httptest.NewRecorder() - h.ServeHTTP(w, r) - - res := w.Result() - var responseBody ErrorResponse - _ = json.NewDecoder(res.Body).Decode(&responseBody) - - assert.Equal(t, res.StatusCode, http.StatusInternalServerError) - assert.Equal(t, http.StatusInternalServerError, responseBody.Status) - assert.Equal(t, "test panic", responseBody.Message) -} From a9ab1eda4c3ec8377c292bc43169e7d05e031666 Mon Sep 17 00:00:00 2001 From: David Mair Spiess Date: Thu, 20 Feb 2025 16:12:33 +0100 Subject: [PATCH 3/3] Add abort handler test --- middlewarex/recoverer_test.go | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/middlewarex/recoverer_test.go b/middlewarex/recoverer_test.go index 1113253..917abbe 100644 --- a/middlewarex/recoverer_test.go +++ b/middlewarex/recoverer_test.go @@ -8,7 +8,25 @@ import ( "github.com/stretchr/testify/assert" ) -func TestRecoverer_WithPlainTextError(t *testing.T) { +func TestRecovererAbortHandler(t *testing.T) { + defer func() { + rcv := recover() + if rcv != http.ErrAbortHandler { + t.Fatalf("http.ErrAbortHandler should not be recovered") + } + }() + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic(http.ErrAbortHandler) + }) + + h := Recoverer(nil)(next) + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) +} + +func TestRecovererCustomErrorResponse(t *testing.T) { plainTextErrorHandler := func(w http.ResponseWriter, r *http.Request, err error) { w.WriteHeader(http.StatusInternalServerError) _, _ = w.Write([]byte("internal server error: " + err.Error()))