Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package pat

import (
"context"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -97,15 +98,52 @@ type PatternServeMux struct {
// NotFound, if set, is used whenever the request doesn't match any
// pattern for its method. NotFound should be set before serving any
// requests.
NotFound http.Handler
handlers map[string][]*patHandler
NotFound http.Handler
handlers map[string][]*patHandler
middlewares []MiddlewareFunc
}

// MiddlewareFunc is a function which receives an http.Handler and returns
// another http.Handler.
// Typically, the returned handler is a closure which does something with the
// http.ResponseWriter and http.Request passed
// to it, and then calls the handler passed as parameter to the MiddlewareFunc.
type MiddlewareFunc func(http.Handler) http.Handler

// middleware interface is anything which implements a MiddlewareFunc named
// Middleware.
type middleware interface {
Middleware(handler http.Handler) http.Handler
}

// Middleware allows MiddlewareFunc to implement the middleware interface.
func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler {
return mw(handler)
}

type contextKey int

const (
// RouteKey is inspired by mux and other routers to preserve the matched
// pattern that can be referenced in the lifetime of a handler.
// This is useful for telemetry and instrumentation to use the pattern
// AND not the whole URL, which can result in cardinality explosion.
// For compatibility with most other mux like gorilla, mux, use count=1
RouteKey contextKey = iota + 1
)

// New returns a new PatternServeMux.
func New() *PatternServeMux {
return &PatternServeMux{handlers: make(map[string][]*patHandler)}
}

// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
func (p *PatternServeMux) Use(mwf ...MiddlewareFunc) {
for _, fn := range mwf {
p.middlewares = append(p.middlewares, fn)
}
}

// ServeHTTP matches r.URL.Path against its routing table using the rules
// described above.
func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand All @@ -114,7 +152,17 @@ func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(params) > 0 && !ph.redirect {
r.URL.RawQuery = url.Values(params).Encode() + "&" + r.URL.RawQuery
}
ph.ServeHTTP(w, r)

// Set the routeKey in context to the current pattern.
ctx := context.WithValue(r.Context(), RouteKey, ph.pat)

h := ph.Handler
// Build middleware chain if no error was found
for i := len(p.middlewares) - 1; i >= 0; i-- {
h = p.middlewares[i].Middleware(h)
}

h.ServeHTTP(w, r.WithContext(ctx))
return
}
}
Expand Down
51 changes: 51 additions & 0 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ func TestPatRoutingHit(t *testing.T) {
if got, want := r.URL.Query().Get(":name"), "keith"; got != want {
t.Errorf("got %q, want %q", got, want)
}

if rk := r.Context().Value(RouteKey); rk != nil {
if rk.(string) != "/foo/:name" {
t.Errorf("routeKey %v does not match /foo/:name", rk)
}
} else {
t.Error("Should've found routeKey /foo/:name")
}

}))

p.ServeHTTP(nil, newRequest("GET", "/foo/keith?a=b", nil))
Expand Down Expand Up @@ -121,6 +130,14 @@ func TestPatNoParams(t *testing.T) {
if r.URL.RawQuery != "" {
t.Errorf("RawQuery was %q; should be empty", r.URL.RawQuery)
}

if rk := r.Context().Value(RouteKey); rk != nil {
if rk.(string) != "/foo/" {
t.Errorf("routeKey %v does not match /foo/:name", rk)
}
} else {
t.Error("Should've found routeKey /foo/:name")
}
}))

p.ServeHTTP(nil, newRequest("GET", "/foo/", nil))
Expand Down Expand Up @@ -262,6 +279,40 @@ func TestEscapedUrl(t *testing.T) {
}
}

func TestMiddleware(t *testing.T) {
mdf := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
q.Set("middleware", "passed")

r.URL.RawQuery = q.Encode()

h.ServeHTTP(w, r)
})
}

p := New()

var middlewareCalled bool
p.Get("/foo/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("%#v", r.URL.Query())
if got, want := r.URL.Query().Get("middleware"), "passed"; got != want {
t.Errorf("got %q, want %q", got, want)
} else {
middlewareCalled = true
}
}))

// use middleware
p.Use(mdf)

p.ServeHTTP(nil, newRequest("GET", "/foo/bad", nil))

if !middlewareCalled {
t.Error("middleware not called")
}
}

func newRequest(method, urlStr string, body io.Reader) *http.Request {
req, err := http.NewRequest(method, urlStr, body)
if err != nil {
Expand Down