diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 5b6d7d4..b76b91a 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -17,5 +17,5 @@ jobs: with: args: sh -c "go get -t -v ./...; gofmt -w -s . && git diff --exit-code; - go tool vet .; + go vet .; go test -v -race ./..." diff --git a/httpcache.go b/httpcache.go index 38982f7..69fe18f 100644 --- a/httpcache.go +++ b/httpcache.go @@ -134,7 +134,7 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { // If there is a stale Response, then any validators it contains will be set on the new request // to give the server a chance to respond with NotModified. If this happens, then the cached Response // will be returned. -func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { +func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { // skipcq: GO-R1005 cacheKey := cacheKey(req) cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" var cachedResp *http.Response @@ -291,7 +291,7 @@ var clock timer = &realClock{} // // Because this is only a private cache, 'public' and 'private' in cache-control aren't // signficant. Similarly, smax-age isn't used. 
-func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { +func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { // skipcq: GO-R1005 respCacheControl := parseCacheControl(respHeaders) reqCacheControl := parseCacheControl(reqHeaders) if _, ok := reqCacheControl["no-cache"]; ok { @@ -435,7 +435,7 @@ func getEndToEndHeaders(respHeaders http.Header) []string { hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{} } } - endToEndHeaders := []string{} + var endToEndHeaders []string for respHeader := range respHeaders { if _, ok := hopByHopHeaders[respHeader]; !ok { endToEndHeaders = append(endToEndHeaders, respHeader) @@ -527,6 +527,9 @@ type cachingReadCloser struct { OnEOF func(io.Reader) buf bytes.Buffer // buf stores a copy of the content of R. + + cached bool + readed bool } // Read reads the next len(p) bytes from R or until R is drained. The @@ -534,15 +537,48 @@ type cachingReadCloser struct { // return, err is io.EOF and OnEOF is called with a full copy of what // has been read so far. func (r *cachingReadCloser) Read(p []byte) (n int, err error) { + r.readed = true n, err = r.R.Read(p) r.buf.Write(p[:n]) - if err == io.EOF || n < len(p) { - r.OnEOF(bytes.NewReader(r.buf.Bytes())) + // we only get an io.EOF if we have a Content-Length (even with + // Transfer-Encoding: chunked we might not get an EOF error marking + // the end). Also in the very weird case that + // none of those are provided, we can only know that we have + // read the content, because something has been read and + // close was called. + if err == io.EOF { + r.cacheIt() } return n, err } +func (r *cachingReadCloser) cacheIt() { + if r.cached { + return + } + r.cached = true + if !r.readed { + // if there was no attempt to read the body, we assume + // it is not used. 
+ return + } + r.OnEOF(bytes.NewReader(r.buf.Bytes())) +} + func (r *cachingReadCloser) Close() error { + // it might happen that when no 'Content-Length' is provided, + // and no 'Transfer-Encoding: chunked' is set, that we do not + // know when the body is fully read. For example, a json decoder + // will read a body until the end of a valid block of json (and + // would not keep reading beyond that): so a body of `{"k":"v"}foo` + // with extra characters, would not be fully read, never reaching + // the EOF, so would not be cached. However, since the connection + // is closed at some point, we can assume that the readed values + // were the good ones for the response. + // + // The problem would be if we had not read anything from the respose + // or we had read a partial response :/ + r.cacheIt() return r.R.Close() } diff --git a/httpcache_test.go b/httpcache_test.go index 99d8099..7f416d7 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "flag" + "fmt" "io" "net/http" "net/http/httptest" @@ -47,7 +48,7 @@ func setup() { mux := http.NewServeMux() s.server = httptest.NewServer(mux) - mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") })) @@ -71,7 +72,7 @@ func setup() { w.Write([]byte("Some text content")) })) - mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Cache-Control", "no-store") })) @@ -93,27 +94,27 @@ func setup() { w.Header().Set("last-modified", lm) })) - mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Cache-Control", 
"max-age=3600") w.Header().Set("Content-Type", "text/plain") w.Header().Set("Vary", "Accept") w.Write([]byte("Some text content")) })) - mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") w.Header().Set("Content-Type", "text/plain") w.Header().Set("Vary", "Accept, Accept-Language") w.Write([]byte("Some text content")) })) - mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") w.Header().Set("Content-Type", "text/plain") w.Header().Add("Vary", "Accept") w.Header().Add("Vary", "Accept-Language") w.Write([]byte("Some text content")) })) - mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") w.Header().Set("Content-Type", "text/plain") w.Header().Set("Vary", "X-Madeup-Header") @@ -144,11 +145,11 @@ func setup() { })) // Take 3 seconds to return 200 OK (for testing client timeouts). 
- mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/3seconds", http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { time.Sleep(3 * time.Second) })) - mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { for { select { case <-s.done: @@ -159,13 +160,65 @@ func setup() { } })) - mux.HandleFunc("/json", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/fast/json", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(longJson))) + w.Write(longJson) + })) + mux.HandleFunc("/slow/json", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(longJson))) + + f, ok := w.(http.Flusher) + if !ok { + w.WriteHeader(http.StatusOK) + return + } + first := len(longJson) / 3 + second := first * 2 + w.Write(longJson[:first]) + f.Flush() + time.Sleep(time.Millisecond * 50) + w.Write(longJson[first:second]) + f.Flush() + time.Sleep(time.Millisecond * 50) + w.Write(longJson[second:]) + f.Flush() + })) + + mux.HandleFunc("/chunked/json", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "application/json") + + f, ok := w.(http.Flusher) + if !ok { + w.WriteHeader(http.StatusOK) + return + } + f.Flush() + + first := len(longJson) / 3 + second := first * 2 + w.Write(longJson[:first]) + f.Flush() + time.Sleep(time.Millisecond * 50) + w.Write(longJson[first:second]) + f.Flush() + time.Sleep(time.Millisecond * 50) + 
w.Write(longJson[second:]) + f.Flush() + })) + + mux.HandleFunc("/weird/json", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") w.Header().Set("Content-Type", "application/json") // This will force using bufio.Read() instead of chunkedReader.Read() // to miss the EOF. w.Header().Set("Transfer-encoding", "identity") - json.NewEncoder(w).Encode(map[string]string{"k": "v"}) + // json.NewEncoder(w).Encode(map[string]string{"k": "v"}) + w.Write(([]byte)(`{"k": "v"}foo`)) })) } @@ -262,7 +315,7 @@ func TestDontServeHeadResponseToGetRequest(t *testing.T) { } } -func TestDontStorePartialRangeInCache(t *testing.T) { +func TestDontStorePartialRangeInCache(t *testing.T) { // skipcq: GO-R1005 resetTest() { req, err := http.NewRequest("GET", s.server.URL+"/range", nil) @@ -410,7 +463,113 @@ func TestCacheOnlyIfBodyRead(t *testing.T) { func TestCacheOnJsonBodyRead(t *testing.T) { resetTest() { - req, err := http.NewRequest("GET", s.server.URL+"/json", nil) + req, err := http.NewRequest("GET", s.server.URL+"/weird/json", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var r json.RawMessage + err = json.NewDecoder(resp.Body).Decode(&r) + if err != nil { + t.Fatal(err) + } + + // the response is cached on close, because server + // is not returning 'Content-Length' nor + // 'Transfer-Encoding: chunked' + resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatalf("XFromCache header isn't blank") + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/weird/json", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf("XFromCache header isn't set") + } + } +} + +func TestCacheOnChunkedJsonBodyRead(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("GET", s.server.URL+"/chunked/json", nil) + if err != 
nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var r json.RawMessage + err = json.NewDecoder(resp.Body).Decode(&r) + if err != nil { + t.Fatal(err) + } + + // in this case, even when Close was not called yet + // since is used chunked it can detect EOF, before Close + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("Bad status code: %d", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "" { + t.Fatalf("XFromCache header isn't blank") + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/chunked/json", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf("XFromCache header isn't set") + } + } +} + +func TestCacheOnFastJsonBodyRead(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("GET", s.server.URL+"/fast/json", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var r json.RawMessage + err = json.NewDecoder(resp.Body).Decode(&r) + if err != nil { + t.Fatal(err) + } + + // in this case, even when Close was not called yet + // since is used chunked it can detect EOF, before Close + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatalf("XFromCache header isn't blank") + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/fast/json", nil) if err != nil { t.Fatal(err) } @@ -419,17 +578,38 @@ func TestCacheOnJsonBodyRead(t *testing.T) { t.Fatal(err) } defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf("XFromCache header isn't set") + } + } +} + +func TestCacheOnSlowJsonBodyRead(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("GET", s.server.URL+"/slow/json", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } var r json.RawMessage err = 
json.NewDecoder(resp.Body).Decode(&r) if err != nil { t.Fatal(err) } + + // in this case, even when Close was not called yet + // since is used chunked it can detect EOF, before Close + defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatalf("XFromCache header isn't blank") } } { - req, err := http.NewRequest("GET", s.server.URL+"/json", nil) + req, err := http.NewRequest("GET", s.server.URL+"/slow/json", nil) if err != nil { t.Fatal(err) } @@ -763,7 +943,7 @@ func TestGetWithDoubleVary(t *testing.T) { } } -func TestGetWith2VaryHeaders(t *testing.T) { +func TestGetWith2VaryHeaders(t *testing.T) { // skipcq: GO-R1005 resetTest() // Tests that multiple Vary headers' comma-separated lists are // merged. See https://github.com/gregjones/httpcache/issues/27. @@ -1519,3 +1699,17 @@ func TestClientTimeout(t *testing.T) { t.Error("client.Do took 2+ seconds, want < 2 seconds") } } + +// we need a json that we can "stream" in small amount of bytes +var longJson []byte = ([]byte)(` +{ + "a": "a_1234567890123456789012345678901234567890", + "b": "b_1234567890123456789012345678901234567890", + "c": "c_1234567890123456789012345678901234567890", + "d": "d_1234567890123456789012345678901234567890", + "e": "e_1234567890123456789012345678901234567890", + "f": "f_1234567890123456789012345678901234567890", + "g": "g_1234567890123456789012345678901234567890", + "h": "h_1234567890123456789012345678901234567890", + "i": "i_1234567890123456789012345678901234567890" +}`)