diff --git a/cachers/http.go b/cachers/http.go index 19c08a1..c982308 100644 --- a/cachers/http.go +++ b/cachers/http.go @@ -9,7 +9,10 @@ import ( "log" "net/http" "strconv" + "strings" + "sync" + "github.com/bradfitz/go-tool-cache/consts" "github.com/pierrec/lz4/v4" ) @@ -42,6 +45,11 @@ type HTTPClient struct { // Get returns a cache miss and Put returns the local disk result, // silently ignoring any HTTP failures (connection errors, server errors, etc.). BestEffortHTTP bool + + // serverCaps records capabilities discovered from the server's + // Gocached-Cap response header. Keys are capability tokens + // (e.g. "putlz4"); values are always true. + serverCaps sync.Map } func (c *HTTPClient) httpClient() *http.Client { @@ -51,6 +59,25 @@ func (c *HTTPClient) httpClient() *http.Client { return http.DefaultClient } +// noteServerCaps records any capabilities declared by the server in the +// Gocached-Cap response header. +func (c *HTTPClient) noteServerCaps(res *http.Response) { + for _, v := range res.Header.Values(consts.CapHeader) { + for tok := range strings.SplitSeq(v, ",") { + if cap := strings.TrimSpace(tok); cap != "" { + if _, ok := c.serverCaps.Load(cap); !ok { + c.serverCaps.Store(cap, true) + } + } + } + } +} + +func (c *HTTPClient) serverHasCap(cap string) bool { + _, ok := c.serverCaps.Load(cap) + return ok +} + // tryDrainResponse reads and throws away a small bounded amount of data from // res.Body. This is a best-effort attempt to allow connection reuse. (Go's // HTTP/1 Transport won't reuse a TCP connection unless you fully consume HTTP @@ -115,6 +142,7 @@ func (c *HTTPClient) Get(ctx context.Context, actionID string) (outputID, diskPa } defer res.Body.Close() defer tryDrainResponse(res) + c.noteServerCaps(res) if res.StatusCode == http.StatusNotFound { return "", "", nil } @@ -167,6 +195,7 @@ func (c *HTTPClient) Get(ctx context.Context, actionID string) (outputID, diskPa } defer res.Body.Close() defer tryDrainResponse(res) + c.noteServerCaps(res) if res.StatusCode == http.StatusNotFound { return "", "", nil } @@ -214,8 +243,31 @@ func (c *HTTPClient) Put(ctx context.Context, actionID, outputID string, size in } }() - req, _ := http.NewRequestWithContext(ctx, "PUT", c.BaseURL+"/"+actionID+"/"+outputID, bytes.NewReader(buf)) - req.ContentLength = size + var reqBody io.Reader = bytes.NewReader(buf) + reqContentLength := size + compressPut := size > 0 && c.serverHasCap(consts.CapPutLZ4) + if compressPut { + var cbuf bytes.Buffer + lzw := lz4.NewWriter(&cbuf) + if err := lzw.Apply(lz4.SizeOption(uint64(size))); err != nil { + return "", err + } + if _, err := lzw.Write(buf); err != nil { + return "", err + } + if err := lzw.Close(); err != nil { + return "", err + } + reqBody = bytes.NewReader(cbuf.Bytes()) + reqContentLength = int64(cbuf.Len()) + } + + req, _ := http.NewRequestWithContext(ctx, "PUT", c.BaseURL+"/"+actionID+"/"+outputID, reqBody) + req.ContentLength = reqContentLength + if compressPut { + req.Header.Set("Content-Encoding", "lz4") + req.Header.Set("X-Uncompressed-Length", strconv.FormatInt(size, 10)) + } if c.AccessToken != "" { req.Header.Set("Authorization", "Bearer "+c.AccessToken) } @@ -226,6 +278,7 @@ func (c *HTTPClient) Put(ctx context.Context, actionID, outputID string, size in httpErr = err } else { defer res.Body.Close() + c.noteServerCaps(res) if res.StatusCode != http.StatusNoContent { msg := tryReadErrorMessage(res) log.Printf("error PUT /%s/%s: %v, %s", actionID, outputID, res.Status, msg) diff --git a/consts/consts.go b/consts/consts.go new file mode 100644 index 0000000..205d536 --- /dev/null +++ b/consts/consts.go @@ -0,0 +1,11 @@ +// Package consts defines shared constants used by both the gocached server +// and the cachers HTTP client. +package consts + +// CapHeader is the HTTP response header used by the server to advertise +// its capabilities as a comma-separated list of tokens. +const CapHeader = "Gocached-Cap" + +// CapPutLZ4 is the Gocached-Cap token indicating that the server accepts +// lz4-compressed PUT request bodies (Content-Encoding: lz4). +const CapPutLZ4 = "putlz4" diff --git a/gocached/gocached.go b/gocached/gocached.go index 0df4f5b..b530c96 100644 --- a/gocached/gocached.go +++ b/gocached/gocached.go @@ -54,11 +54,13 @@ import ( "path/filepath" "reflect" "slices" + "strconv" "strings" "sync" "sync/atomic" "time" + "github.com/bradfitz/go-tool-cache/consts" ijwt "github.com/bradfitz/go-tool-cache/gocached/internal/jwt" "github.com/pierrec/lz4/v4" "github.com/prometheus/client_golang/prometheus" @@ -563,6 +565,8 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { srv.logf("ServeHTTP: %s %s", r.Method, r.RequestURI) } + w.Header().Set(consts.CapHeader, consts.CapPutLZ4) + var sessionData *sessionData // remains nil for unauthenticated requests. reqStats := &stats{} defer func() { @@ -962,14 +966,37 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) return } + // Determine the body reader and content size. If the client sent + // lz4-compressed data, wrap the body with a decompressor and use the + // uncompressed size for all downstream logic (hashing, inline vs disk + // decisions, etc.). The server will re-compress when writing to disk. + // The goal is reducing bytes on the wire, not saving CPU, so the + // decompress-then-recompress round trip is fine. + var bodyReader io.Reader = r.Body + contentSize := r.ContentLength + if r.Header.Get("Content-Encoding") == "lz4" { + sizeStr := r.Header.Get("X-Uncompressed-Length") + if sizeStr == "" { + http.Error(w, "lz4-compressed PUT missing X-Uncompressed-Length", http.StatusBadRequest) + return + } + var err error + contentSize, err = strconv.ParseInt(sizeStr, 10, 64) + if err != nil { + http.Error(w, "invalid X-Uncompressed-Length", http.StatusBadRequest) + return + } + bodyReader = lz4.NewReader(r.Body) + } + hasher := sha256.New() - hashingBody := io.TeeReader(r.Body, hasher) + hashingBody := io.TeeReader(bodyReader, hasher) - storedSize := r.ContentLength - uncompressedSize := r.ContentLength + storedSize := contentSize + uncompressedSize := contentSize var smallData []byte - if r.ContentLength <= smallObjectSize { + if contentSize <= smallObjectSize { // Store small objects inline in the database. var err error smallData, err = io.ReadAll(hashingBody) @@ -978,15 +1005,13 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) http.Error(w, "Read content error", http.StatusInternalServerError) return } - if int64(len(smallData)) != r.ContentLength { - // This check is redundant with net/http's validation, but - // for extra clarity. + if int64(len(smallData)) != contentSize { http.Error(w, "bad content length", http.StatusInternalServerError) return } } else { // For larger objects, we store them on disk (lz4 compressed). - diskSize, err := s.writeDiskBlob(r.ContentLength, hashingBody) + diskSize, err := s.writeDiskBlob(contentSize, hashingBody) if err != nil { s.logf("Write disk blob error: %v", err) http.Error(w, "Write disk blob error", http.StatusInternalServerError) @@ -1049,7 +1074,7 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) } stats.Puts++ - stats.PutsBytes += r.ContentLength + stats.PutsBytes += contentSize if smallData != nil { stats.PutsInline++ } diff --git a/gocached/gocached_test.go b/gocached/gocached_test.go index c87fc6a..baccf30 100644 --- a/gocached/gocached_test.go +++ b/gocached/gocached_test.go @@ -614,6 +614,84 @@ func TestLZ4Storage(t *testing.T) { } } +// roundTripFunc is an http.RoundTripper implemented as a function. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestCompressPuts(t *testing.T) { + st := newServerTester(t) + + // Track whether PUT requests are sent with lz4 compression on the wire. + var putLZ4Count atomic.Int32 + c := st.mkClient() + c.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if req.Method == "PUT" && req.Header.Get("Content-Encoding") == "lz4" { + putLZ4Count.Add(1) + } + return http.DefaultTransport.RoundTrip(req) + }), + } + + // The first Put goes out uncompressed because the client hasn't + // yet seen a Gocached-Cap response. The Put response itself + // carries Gocached-Cap: putlz4, so subsequent puts are compressed. + st.wantPut(c, "0010", "9010", "first-put") + if n := putLZ4Count.Load(); n != 0 { + t.Fatalf("expected 0 lz4 puts before cap discovery, got %d", n) + } + + // Now the client knows the server supports putlz4. + tests := []struct { + name string + val string + }{ + {"empty", ""}, + {"small_inline", "hello"}, + {"inline_max", strings.Repeat("a", smallObjectSize)}, + {"disk_no_lz4", strings.Repeat("b", smallObjectSize+1)}, + {"disk_lz4_threshold", strings.Repeat("c", lz4CompressThreshold)}, + {"disk_large", strings.Repeat("d", 4096)}, + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actionID := fmt.Sprintf("%04x", i+0x20) + outputID := fmt.Sprintf("%04x", i+0xa0) + + st.wantPut(c, actionID, outputID, tt.val) + + // Get from a fresh client to verify the server correctly + // decompressed the lz4 wire body and stored the original data. + cFresh := st.mkClient() + st.wantGet(cFresh, actionID, outputID, tt.val) + }) + } + + // All puts with size > 0 should have been lz4-compressed on the wire. + // "empty" has size=0 and is not compressed. + wantLZ4 := int32(len(tests) - 1) // all except "empty" + if n := putLZ4Count.Load(); n != wantLZ4 { + t.Errorf("expected %d lz4 puts after cap discovery, got %d", wantLZ4, n) + } + + // Verify dedup: same content from a fresh client (no caps, sends + // uncompressed) and the cap-aware client (sends compressed) produces + // the same blob (same SHA-256). + cNoCaps := st.mkClient() + largeVal := strings.Repeat("e", 4096) + + st.wantPut(cNoCaps, "aa01", "bb01", largeVal) + st.wantPut(c, "aa02", "bb01", largeVal) + + cFresh := st.mkClient() + st.wantGet(cFresh, "aa01", "bb01", largeVal) + st.wantGet(cFresh, "aa02", "bb01", largeVal) +} + func TestClientConnReuse(t *testing.T) { st := newServerTester(t)