From abd9555e478f7463df5f5e3da7016fd728e2a8bb Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Wed, 20 Jul 2016 10:23:41 -0400 Subject: [PATCH 1/5] start extracting the code for AllowedMethods and Lookup --- mux.go | 60 ++++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/mux.go b/mux.go index ec86fcd..5079a27 100644 --- a/mux.go +++ b/mux.go @@ -91,13 +91,14 @@ import ( // Status to "405 Method Not Allowed". // // If the NotFound handler is set, then it is used whenever the pattern doesn't -// match the request path for the current method (and the Allow header is not -// altered). +// match the request path for any method (if it does match but for a different +// method, the 405 Method Not Allowed response is returned instead of a 404). type PatternServeMux struct { // NotFound, if set, is used whenever the request doesn't match any - // pattern for its method. NotFound should be set before serving any + // pattern for any method. NotFound should be set before serving any // requests. NotFound http.Handler + handlers map[string][]*patHandler } @@ -106,6 +107,37 @@ func New() *PatternServeMux { return &PatternServeMux{handlers: make(map[string][]*patHandler)} } +// Lookup returns the handler that matches the specified method and +// path. If no registered handlers are found, it returns nil (that is, +// it doesn't return the NotFound handler or the handler to return +// the 405 Method Not Allowed response). This can be useful in a middleware +// to find out if a request actually matches a registered handler. +func (p *PatternServeMux) Lookup(method, path string) http.Handler { + return nil +} + +// AllowedMethods returns the list of registered methods for the +// specified path. +func (p *PatternServeMux) AllowedMethods(path string) []string { + return p.allowedMethods(path, "") +} + +func (p *PatternServeMux) allowedMethods(path, skip string) []string { + var allowed []string + for meth, handlers := range p.handlers { + if meth == skip { + continue + } + + for _, ph := range handlers { + if _, ok := ph.try(path); ok { + allowed = append(allowed, meth) + } + } + } + return allowed +} + // ServeHTTP matches r.URL.Path against its routing table using the rules // described above. func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -119,25 +151,13 @@ func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - if p.NotFound != nil { - p.NotFound.ServeHTTP(w, r) - return - } - - allowed := make([]string, 0, len(p.handlers)) - for meth, handlers := range p.handlers { - if meth == r.Method { - continue - } - - for _, ph := range handlers { - if _, ok := ph.try(r.URL.Path); ok { - allowed = append(allowed, meth) - } + allowed := p.allowedMethods(r.URL.Path, r.Method) + if len(allowed) == 0 { + if p.NotFound != nil { + p.NotFound.ServeHTTP(w, r) + return } - } - if len(allowed) == 0 { http.NotFound(w, r) return } From 2c162a6b936885ada7fae4fce211fbea603226a1 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Wed, 20 Jul 2016 18:45:27 -0400 Subject: [PATCH 2/5] implement Lookup, fix NotFound behaviour in test --- mux.go | 22 +++++++++++++--------- mux_test.go | 6 +++--- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/mux.go b/mux.go index 5079a27..7290bce 100644 --- a/mux.go +++ b/mux.go @@ -98,7 +98,6 @@ type PatternServeMux struct { // pattern for any method. NotFound should be set before serving any // requests. NotFound http.Handler - handlers map[string][]*patHandler } @@ -113,6 +112,16 @@ func New() *PatternServeMux { // the 405 Method Not Allowed response). This can be useful in a middleware // to find out if a request actually matches a registered handler. func (p *PatternServeMux) Lookup(method, path string) http.Handler { + for _, ph := range p.handlers[method] { + if params, ok := ph.try(path); ok { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(params) > 0 && !ph.redirect { + r.URL.RawQuery = url.Values(params).Encode() + "&" + r.URL.RawQuery + } + ph.ServeHTTP(w, r) + }) + } + } return nil } @@ -141,14 +150,9 @@ func (p *PatternServeMux) allowedMethods(path, skip string) []string { // ServeHTTP matches r.URL.Path against its routing table using the rules // described above. func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - for _, ph := range p.handlers[r.Method] { - if params, ok := ph.try(r.URL.Path); ok { - if len(params) > 0 && !ph.redirect { - r.URL.RawQuery = url.Values(params).Encode() + "&" + r.URL.RawQuery - } - ph.ServeHTTP(w, r) - return - } + if h := p.Lookup(r.Method, r.URL.Path); h != nil { + h.ServeHTTP(w, r) + return } allowed := p.allowedMethods(r.URL.Path, r.Method) diff --git a/mux_test.go b/mux_test.go index bce39bc..48c324b 100644 --- a/mux_test.go +++ b/mux_test.go @@ -213,11 +213,11 @@ func TestNotFound(t *testing.T) { }) p.Post("/bar", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - for _, path := range []string{"/foo", "/bar"} { + for path, want := range map[string]int{"/foo": 123, "/bar": 405} { res := httptest.NewRecorder() p.ServeHTTP(res, newRequest("GET", path, nil)) - if res.Code != 123 { - t.Errorf("for path %q: got code %d; want 123", path, res.Code) + if res.Code != want { + t.Errorf("for path %q: got code %d; want %d", path, res.Code, want) } } } From 848a2886e148a89db588d001608d9b9eb55ad15b Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Thu, 21 Jul 2016 09:03:27 -0400 Subject: [PATCH 3/5] test the Lookup and AllowedMethods methods --- mux_test.go | 76 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/mux_test.go b/mux_test.go index 48c324b..c48aabc 100644 --- a/mux_test.go +++ b/mux_test.go @@ -244,6 +244,82 @@ func TestMethodPatch(t *testing.T) { } } +func TestAllowedMethods(t *testing.T) { + p := New() + p.Get("/a", http.NotFoundHandler()) + p.Post("/a", http.NotFoundHandler()) + p.Post("/b", http.NotFoundHandler()) + p.Del("/a", http.NotFoundHandler()) + p.Patch("/b", http.NotFoundHandler()) + p.Patch("/b/", http.NotFoundHandler()) + + cases := []struct { + path string + want []string + }{ + {"/a", []string{"DELETE", "GET", "HEAD", "POST"}}, + {"/b", []string{"PATCH", "POST"}}, + {"/c", nil}, + } + for _, c := range cases { + got := p.AllowedMethods(c.path) + sort.Strings(got) + if !reflect.DeepEqual(c.want, got) { + t.Errorf("%s: got %v; want %v", c.path, got, c.want) + } + } +} + +type statusHandler int + +func (h statusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(int(h)) +} + +func TestLookup(t *testing.T) { + p := New() + + // create N different handlers + handlers := make([]http.Handler, 5) + for i := 0; i < len(handlers); i++ { + handlers[i] = statusHandler(i) + } + + register := []struct { + method string + path string + handlerIndex int + }{ + {"HEAD", "/a", 0}, + {"GET", "/a", 1}, + {"POST", "/a", 2}, + {"GET", "/b", 3}, + {"DELETE", "/b", 4}, + } + + // register the handlers + for _, r := range register { + p.Add(r.method, r.path, handlers[r.handlerIndex]) + } + + // assert the returned Lookup handler + for _, r := range register { + h := p.Lookup(r.method, r.path) + if h == nil { + t.Errorf("%s %s: lookup returned nil handler", r.method, r.path) + } + w := httptest.NewRecorder() + req, _ := http.NewRequest("", "/", nil) + h.ServeHTTP(w, req) + if w.Code != r.handlerIndex { + t.Errorf("%s %s: handler status code; got %v; want %v", r.method, r.path, w.Code, r.handlerIndex) + } + } + if h := p.Lookup("GET", "/c"); h != nil { + t.Errorf("GET /c: handler returned non-nil handler") + } +} + func newRequest(method, urlStr string, body io.Reader) *http.Request { req, err := http.NewRequest(method, urlStr, body) if err != nil { From 2080f06d30c29e0b8d61eaa64378e2b26bce3822 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Thu, 21 Jul 2016 09:14:32 -0400 Subject: [PATCH 4/5] use the newRequest helper in tests --- mux_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mux_test.go b/mux_test.go index c48aabc..3a7fdbb 100644 --- a/mux_test.go +++ b/mux_test.go @@ -309,8 +309,7 @@ func TestLookup(t *testing.T) { t.Errorf("%s %s: lookup returned nil handler", r.method, r.path) } w := httptest.NewRecorder() - req, _ := http.NewRequest("", "/", nil) - h.ServeHTTP(w, req) + h.ServeHTTP(w, newRequest("", "/", nil)) if w.Code != r.handlerIndex { t.Errorf("%s %s: handler status code; got %v; want %v", r.method, r.path, w.Code, r.handlerIndex) } From 71f069ac05a0f93169200fd95523fff8d790e726 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Thu, 21 Jul 2016 21:46:06 -0400 Subject: [PATCH 5/5] add RegisteredPatterns method --- mux.go | 19 +++++++++++++++++++ mux_test.go | 15 +++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/mux.go b/mux.go index 7290bce..88c6d47 100644 --- a/mux.go +++ b/mux.go @@ -4,6 +4,7 @@ package pat import ( "net/http" "net/url" + "sort" "strings" ) @@ -125,6 +126,24 @@ func (p *PatternServeMux) Lookup(method, path string) http.Handler { return nil } +// RegisteredPatterns returns a list of unique registered patterns. +// As long as a pattern has been registered for one method, it is +// returned. The list is lexically sorted. +func (p *PatternServeMux) RegisteredPatterns() []string { + set := make(map[string]bool) + for _, pats := range p.handlers { + for _, ph := range pats { + set[ph.pat] = true + } + } + list := make([]string, 0, len(set)) + for k := range set { + list = append(list, k) + } + sort.Strings(list) + return list +} + // AllowedMethods returns the list of registered methods for the // specified path. func (p *PatternServeMux) AllowedMethods(path string) []string { diff --git a/mux_test.go b/mux_test.go index 3a7fdbb..24bc82b 100644 --- a/mux_test.go +++ b/mux_test.go @@ -244,6 +244,21 @@ func TestMethodPatch(t *testing.T) { } } +func TestRegisteredPatterns(t *testing.T) { + p := New() + p.Get("/a", http.NotFoundHandler()) + p.Post("/b", http.NotFoundHandler()) + p.Del("/a", http.NotFoundHandler()) + p.Patch("/b", http.NotFoundHandler()) + p.Patch("/b/", http.NotFoundHandler()) + + pats := p.RegisteredPatterns() + want := []string{"/a", "/b", "/b/"} + if !reflect.DeepEqual(want, pats) { + t.Errorf("got %v; want %v", pats, want) + } +} + func TestAllowedMethods(t *testing.T) { p := New() p.Get("/a", http.NotFoundHandler())