Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions middlewarex/recoverer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package middlewarex

import (
"errors"
"net/http"
)

type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error)

func Recoverer(errorHandler ErrorHandler) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rvr := recover(); rvr != nil {
if rvr == http.ErrAbortHandler {
// Don’t recover http.ErrAbortHandler so the response to
// the client is aborted.
panic(rvr)
}

var err error
switch x := rvr.(type) {
case string:
err = errors.New(x)
case error:
err = x
default:
err = errors.New("unknown panic")
}

errorHandler(w, r, err)
}
}()

next.ServeHTTP(w, r)
}

return http.HandlerFunc(fn)
}
}
47 changes: 47 additions & 0 deletions middlewarex/recoverer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package middlewarex

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestRecovererAbortHandler(t *testing.T) {
defer func() {
rcv := recover()
if rcv != http.ErrAbortHandler {
t.Fatalf("http.ErrAbortHandler should not be recovered")
}
}()

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic(http.ErrAbortHandler)
})

h := Recoverer(nil)(next)
r := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
}

func TestRecovererCustomErrorResponse(t *testing.T) {
plainTextErrorHandler := func(w http.ResponseWriter, r *http.Request, err error) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("internal server error: " + err.Error()))
}

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("test panic")
})

h := Recoverer(plainTextErrorHandler)(next)
r := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, r)

res := w.Result()
assert.Equal(t, res.StatusCode, http.StatusInternalServerError)
assert.Equal(t, w.Body.String(), "internal server error: test panic")
}