Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 62 additions & 19 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package pat
import (
"net/http"
"net/url"
"sort"
"strings"
)

Expand Down Expand Up @@ -91,11 +92,11 @@ import (
// Status to "405 Method Not Allowed".
//
// If the NotFound handler is set, then it is used whenever the pattern doesn't
// match the request path for the current method (and the Allow header is not
// altered).
// match the request path for any method (if it does match but for a different
// method, the 405 Method Not Allowed response is returned instead of a 404).
type PatternServeMux struct {
// NotFound, if set, is used whenever the request doesn't match any
// pattern for its method. NotFound should be set before serving any
// pattern for any method. NotFound should be set before serving any
// requests.
NotFound http.Handler
handlers map[string][]*patHandler
Expand All @@ -106,38 +107,80 @@ func New() *PatternServeMux {
return &PatternServeMux{handlers: make(map[string][]*patHandler)}
}

// ServeHTTP matches r.URL.Path against its routing table using the rules
// described above.
func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for _, ph := range p.handlers[r.Method] {
if params, ok := ph.try(r.URL.Path); ok {
if len(params) > 0 && !ph.redirect {
r.URL.RawQuery = url.Values(params).Encode() + "&" + r.URL.RawQuery
}
ph.ServeHTTP(w, r)
return
// Lookup returns the handler that matches the specified method and
// path. If no registered handlers are found, it returns nil (that is,
// it doesn't return the NotFound handler or the handler to return
// the 405 Method Not Allowed response). This can be useful in a middleware
// to find out if a request actually matches a registered handler.
func (p *PatternServeMux) Lookup(method, path string) http.Handler {
for _, ph := range p.handlers[method] {
if params, ok := ph.try(path); ok {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if len(params) > 0 && !ph.redirect {
r.URL.RawQuery = url.Values(params).Encode() + "&" + r.URL.RawQuery
}
ph.ServeHTTP(w, r)
})
}
}
return nil
}

if p.NotFound != nil {
p.NotFound.ServeHTTP(w, r)
return
// RegisteredPatterns returns a list of unique registered patterns.
// As long as a pattern has been registered for one method, it is
// returned. The list is lexically sorted.
func (p *PatternServeMux) RegisteredPatterns() []string {
set := make(map[string]bool)
for _, pats := range p.handlers {
for _, ph := range pats {
set[ph.pat] = true
}
}
list := make([]string, 0, len(set))
for k := range set {
list = append(list, k)
}
sort.Strings(list)
return list
}

// AllowedMethods returns the list of registered methods for the
// specified path.
func (p *PatternServeMux) AllowedMethods(path string) []string {
return p.allowedMethods(path, "")
}

allowed := make([]string, 0, len(p.handlers))
func (p *PatternServeMux) allowedMethods(path, skip string) []string {
var allowed []string
for meth, handlers := range p.handlers {
if meth == r.Method {
if meth == skip {
continue
}

for _, ph := range handlers {
if _, ok := ph.try(r.URL.Path); ok {
if _, ok := ph.try(path); ok {
allowed = append(allowed, meth)
}
}
}
return allowed
}

// ServeHTTP matches r.URL.Path against its routing table using the rules
// described above.
func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h := p.Lookup(r.Method, r.URL.Path); h != nil {
h.ServeHTTP(w, r)
return
}

allowed := p.allowedMethods(r.URL.Path, r.Method)
if len(allowed) == 0 {
if p.NotFound != nil {
p.NotFound.ServeHTTP(w, r)
return
}

http.NotFound(w, r)
return
}
Expand Down
96 changes: 93 additions & 3 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,11 @@ func TestNotFound(t *testing.T) {
})
p.Post("/bar", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

for _, path := range []string{"/foo", "/bar"} {
for path, want := range map[string]int{"/foo": 123, "/bar": 405} {
res := httptest.NewRecorder()
p.ServeHTTP(res, newRequest("GET", path, nil))
if res.Code != 123 {
t.Errorf("for path %q: got code %d; want 123", path, res.Code)
if res.Code != want {
t.Errorf("for path %q: got code %d; want %d", path, res.Code, want)
}
}
}
Expand All @@ -244,6 +244,96 @@ func TestMethodPatch(t *testing.T) {
}
}

func TestRegisteredPatterns(t *testing.T) {
p := New()
p.Get("/a", http.NotFoundHandler())
p.Post("/b", http.NotFoundHandler())
p.Del("/a", http.NotFoundHandler())
p.Patch("/b", http.NotFoundHandler())
p.Patch("/b/", http.NotFoundHandler())

pats := p.RegisteredPatterns()
want := []string{"/a", "/b", "/b/"}
if !reflect.DeepEqual(want, pats) {
t.Errorf("got %v; want %v", pats, want)
}
}

func TestAllowedMethods(t *testing.T) {
p := New()
p.Get("/a", http.NotFoundHandler())
p.Post("/a", http.NotFoundHandler())
p.Post("/b", http.NotFoundHandler())
p.Del("/a", http.NotFoundHandler())
p.Patch("/b", http.NotFoundHandler())
p.Patch("/b/", http.NotFoundHandler())

cases := []struct {
path string
want []string
}{
{"/a", []string{"DELETE", "GET", "HEAD", "POST"}},
{"/b", []string{"PATCH", "POST"}},
{"/c", nil},
}
for _, c := range cases {
got := p.AllowedMethods(c.path)
sort.Strings(got)
if !reflect.DeepEqual(c.want, got) {
t.Errorf("%s: got %v; want %v", c.path, got, c.want)
}
}
}

type statusHandler int

func (h statusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(int(h))
}

func TestLookup(t *testing.T) {
p := New()

// create N different handlers
handlers := make([]http.Handler, 5)
for i := 0; i < len(handlers); i++ {
handlers[i] = statusHandler(i)
}

register := []struct {
method string
path string
handlerIndex int
}{
{"HEAD", "/a", 0},
{"GET", "/a", 1},
{"POST", "/a", 2},
{"GET", "/b", 3},
{"DELETE", "/b", 4},
}

// register the handlers
for _, r := range register {
p.Add(r.method, r.path, handlers[r.handlerIndex])
}

// assert the returned Lookup handler
for _, r := range register {
h := p.Lookup(r.method, r.path)
if h == nil {
t.Errorf("%s %s: lookup returned nil handler", r.method, r.path)
}
w := httptest.NewRecorder()
h.ServeHTTP(w, newRequest("", "/", nil))
if w.Code != r.handlerIndex {
t.Errorf("%s %s: handler status code; got %v; want %v", r.method, r.path, w.Code, r.handlerIndex)
}
}
if h := p.Lookup("GET", "/c"); h != nil {
t.Errorf("GET /c: handler returned non-nil handler")
}
}

func newRequest(method, urlStr string, body io.Reader) *http.Request {
req, err := http.NewRequest(method, urlStr, body)
if err != nil {
Expand Down