diff --git a/mux.go b/mux.go index ec86fcd..88c6d47 100644 --- a/mux.go +++ b/mux.go @@ -4,6 +4,7 @@ package pat import ( "net/http" "net/url" + "sort" "strings" ) @@ -91,11 +92,11 @@ import ( // Status to "405 Method Not Allowed". // // If the NotFound handler is set, then it is used whenever the pattern doesn't -// match the request path for the current method (and the Allow header is not -// altered). +// match the request path for any method (if it does match but for a different +// method, the 405 Method Not Allowed response is returned instead of a 404). type PatternServeMux struct { // NotFound, if set, is used whenever the request doesn't match any - // pattern for its method. NotFound should be set before serving any + // pattern for any method. NotFound should be set before serving any // requests. NotFound http.Handler handlers map[string][]*patHandler @@ -106,38 +107,80 @@ func New() *PatternServeMux { return &PatternServeMux{handlers: make(map[string][]*patHandler)} } -// ServeHTTP matches r.URL.Path against its routing table using the rules -// described above. -func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - for _, ph := range p.handlers[r.Method] { - if params, ok := ph.try(r.URL.Path); ok { - if len(params) > 0 && !ph.redirect { - r.URL.RawQuery = url.Values(params).Encode() + "&" + r.URL.RawQuery - } - ph.ServeHTTP(w, r) - return +// Lookup returns the handler that matches the specified method and +// path. If no registered handlers are found, it returns nil (that is, +// it doesn't return the NotFound handler or the handler to return +// the 405 Method Not Allowed response). This can be useful in a middleware +// to find out if a request actually matches a registered handler. +func (p *PatternServeMux) Lookup(method, path string) http.Handler { + for _, ph := range p.handlers[method] { + if params, ok := ph.try(path); ok { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(params) > 0 && !ph.redirect { + r.URL.RawQuery = url.Values(params).Encode() + "&" + r.URL.RawQuery + } + ph.ServeHTTP(w, r) + }) } } + return nil +} - if p.NotFound != nil { - p.NotFound.ServeHTTP(w, r) - return +// RegisteredPatterns returns a list of unique registered patterns. +// As long as a pattern has been registered for one method, it is +// returned. The list is lexically sorted. +func (p *PatternServeMux) RegisteredPatterns() []string { + set := make(map[string]bool) + for _, pats := range p.handlers { + for _, ph := range pats { + set[ph.pat] = true + } + } + list := make([]string, 0, len(set)) + for k := range set { + list = append(list, k) } + sort.Strings(list) + return list +} + +// AllowedMethods returns the list of registered methods for the +// specified path. +func (p *PatternServeMux) AllowedMethods(path string) []string { + return p.allowedMethods(path, "") +} - allowed := make([]string, 0, len(p.handlers)) +func (p *PatternServeMux) allowedMethods(path, skip string) []string { + var allowed []string for meth, handlers := range p.handlers { - if meth == r.Method { + if meth == skip { continue } for _, ph := range handlers { - if _, ok := ph.try(r.URL.Path); ok { + if _, ok := ph.try(path); ok { allowed = append(allowed, meth) } } } + return allowed +} +// ServeHTTP matches r.URL.Path against its routing table using the rules +// described above. +func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h := p.Lookup(r.Method, r.URL.Path); h != nil { + h.ServeHTTP(w, r) + return + } + + allowed := p.allowedMethods(r.URL.Path, r.Method) if len(allowed) == 0 { + if p.NotFound != nil { + p.NotFound.ServeHTTP(w, r) + return + } + http.NotFound(w, r) return } diff --git a/mux_test.go b/mux_test.go index bce39bc..24bc82b 100644 --- a/mux_test.go +++ b/mux_test.go @@ -213,11 +213,11 @@ func TestNotFound(t *testing.T) { }) p.Post("/bar", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - for _, path := range []string{"/foo", "/bar"} { + for path, want := range map[string]int{"/foo": 123, "/bar": 405} { res := httptest.NewRecorder() p.ServeHTTP(res, newRequest("GET", path, nil)) - if res.Code != 123 { - t.Errorf("for path %q: got code %d; want 123", path, res.Code) + if res.Code != want { + t.Errorf("for path %q: got code %d; want %d", path, res.Code, want) } } } @@ -244,6 +244,96 @@ func TestMethodPatch(t *testing.T) { } } +func TestRegisteredPatterns(t *testing.T) { + p := New() + p.Get("/a", http.NotFoundHandler()) + p.Post("/b", http.NotFoundHandler()) + p.Del("/a", http.NotFoundHandler()) + p.Patch("/b", http.NotFoundHandler()) + p.Patch("/b/", http.NotFoundHandler()) + + pats := p.RegisteredPatterns() + want := []string{"/a", "/b", "/b/"} + if !reflect.DeepEqual(want, pats) { + t.Errorf("got %v; want %v", pats, want) + } +} + +func TestAllowedMethods(t *testing.T) { + p := New() + p.Get("/a", http.NotFoundHandler()) + p.Post("/a", http.NotFoundHandler()) + p.Post("/b", http.NotFoundHandler()) + p.Del("/a", http.NotFoundHandler()) + p.Patch("/b", http.NotFoundHandler()) + p.Patch("/b/", http.NotFoundHandler()) + + cases := []struct { + path string + want []string + }{ + {"/a", []string{"DELETE", "GET", "HEAD", "POST"}}, + {"/b", []string{"PATCH", "POST"}}, + {"/c", nil}, + } + for _, c := range cases { + got := p.AllowedMethods(c.path) + sort.Strings(got) + if !reflect.DeepEqual(c.want, got) { + t.Errorf("%s: got %v; want %v", c.path, got, c.want) + } + } +} + +type statusHandler int + +func (h statusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(int(h)) +} + +func TestLookup(t *testing.T) { + p := New() + + // create N different handlers + handlers := make([]http.Handler, 5) + for i := 0; i < len(handlers); i++ { + handlers[i] = statusHandler(i) + } + + register := []struct { + method string + path string + handlerIndex int + }{ + {"HEAD", "/a", 0}, + {"GET", "/a", 1}, + {"POST", "/a", 2}, + {"GET", "/b", 3}, + {"DELETE", "/b", 4}, + } + + // register the handlers + for _, r := range register { + p.Add(r.method, r.path, handlers[r.handlerIndex]) + } + + // assert the returned Lookup handler + for _, r := range register { + h := p.Lookup(r.method, r.path) + if h == nil { + t.Errorf("%s %s: lookup returned nil handler", r.method, r.path) + } + w := httptest.NewRecorder() + h.ServeHTTP(w, newRequest("", "/", nil)) + if w.Code != r.handlerIndex { + t.Errorf("%s %s: handler status code; got %v; want %v", r.method, r.path, w.Code, r.handlerIndex) + } + } + if h := p.Lookup("GET", "/c"); h != nil { + t.Errorf("GET /c: handler returned non-nil handler") + } +} + func newRequest(method, urlStr string, body io.Reader) *http.Request { req, err := http.NewRequest(method, urlStr, body) if err != nil {