Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions cachers/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import (
"log"
"net/http"
"strconv"
"strings"
"sync"

"github.com/bradfitz/go-tool-cache/consts"
"github.com/pierrec/lz4/v4"
)

Expand Down Expand Up @@ -42,6 +45,11 @@ type HTTPClient struct {
// Get returns a cache miss and Put returns the local disk result,
// silently ignoring any HTTP failures (connection errors, server errors, etc.).
BestEffortHTTP bool

// serverCaps records capabilities discovered from the server's
// Gocached-Cap response header. Keys are capability tokens
// (e.g. "putlz4"); values are always true.
serverCaps sync.Map
}

func (c *HTTPClient) httpClient() *http.Client {
Expand All @@ -51,6 +59,25 @@ func (c *HTTPClient) httpClient() *http.Client {
return http.DefaultClient
}

// noteServerCaps records any capabilities declared by the server in the
// Gocached-Cap response header of res.
//
// Every value of that header is treated as a comma-separated list. Each
// token is trimmed of surrounding whitespace, empty tokens are skipped, and
// the remaining tokens are added to c.serverCaps with the value true.
func (c *HTTPClient) noteServerCaps(res *http.Response) {
	for _, v := range res.Header.Values(consts.CapHeader) {
		for tok := range strings.SplitSeq(v, ",") {
			// Named "name" rather than "cap" so the predeclared cap
			// builtin is not shadowed inside the loop.
			name := strings.TrimSpace(tok)
			if name == "" {
				continue
			}
			// LoadOrStore stores only when the token is not present yet,
			// so a capability that is already recorded is left untouched.
			c.serverCaps.LoadOrStore(name, true)
		}
	}
}

// serverHasCap reports whether the capability token cap is present in
// c.serverCaps, i.e. whether it has been recorded from a server response.
func (c *HTTPClient) serverHasCap(cap string) bool {
	_, known := c.serverCaps.Load(cap)
	return known
}

// tryDrainResponse reads and throws away a small bounded amount of data from
// res.Body. This is a best-effort attempt to allow connection reuse. (Go's
// HTTP/1 Transport won't reuse a TCP connection unless you fully consume HTTP
Expand Down Expand Up @@ -115,6 +142,7 @@ func (c *HTTPClient) Get(ctx context.Context, actionID string) (outputID, diskPa
}
defer res.Body.Close()
defer tryDrainResponse(res)
c.noteServerCaps(res)
if res.StatusCode == http.StatusNotFound {
return "", "", nil
}
Expand Down Expand Up @@ -167,6 +195,7 @@ func (c *HTTPClient) Get(ctx context.Context, actionID string) (outputID, diskPa
}
defer res.Body.Close()
defer tryDrainResponse(res)
c.noteServerCaps(res)
if res.StatusCode == http.StatusNotFound {
return "", "", nil
}
Expand Down Expand Up @@ -214,8 +243,31 @@ func (c *HTTPClient) Put(ctx context.Context, actionID, outputID string, size in
}
}()

req, _ := http.NewRequestWithContext(ctx, "PUT", c.BaseURL+"/"+actionID+"/"+outputID, bytes.NewReader(buf))
req.ContentLength = size
var reqBody io.Reader = bytes.NewReader(buf)
reqContentLength := size
compressPut := size > 0 && c.serverHasCap(consts.CapPutLZ4)
if compressPut {
var cbuf bytes.Buffer
lzw := lz4.NewWriter(&cbuf)
if err := lzw.Apply(lz4.SizeOption(uint64(size))); err != nil {
return "", err
}
if _, err := lzw.Write(buf); err != nil {
return "", err
}
if err := lzw.Close(); err != nil {
return "", err
}
reqBody = bytes.NewReader(cbuf.Bytes())
reqContentLength = int64(cbuf.Len())
}

req, _ := http.NewRequestWithContext(ctx, "PUT", c.BaseURL+"/"+actionID+"/"+outputID, reqBody)
req.ContentLength = reqContentLength
if compressPut {
req.Header.Set("Content-Encoding", "lz4")
req.Header.Set("X-Uncompressed-Length", strconv.FormatInt(size, 10))
}
if c.AccessToken != "" {
req.Header.Set("Authorization", "Bearer "+c.AccessToken)
}
Expand All @@ -226,6 +278,7 @@ func (c *HTTPClient) Put(ctx context.Context, actionID, outputID string, size in
httpErr = err
} else {
defer res.Body.Close()
c.noteServerCaps(res)
if res.StatusCode != http.StatusNoContent {
msg := tryReadErrorMessage(res)
log.Printf("error PUT /%s/%s: %v, %s", actionID, outputID, res.Status, msg)
Expand Down
11 changes: 11 additions & 0 deletions consts/consts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Package consts defines shared constants used by both the gocached server
// and the cachers HTTP client.
package consts

// Names used by the capability-advertisement exchange between the gocached
// server and the cachers HTTP client.
const (
	// CapHeader is the name of the HTTP response header in which the
	// server advertises its capabilities, as a comma-separated list of
	// tokens.
	CapHeader = "Gocached-Cap"

	// CapPutLZ4 is the Gocached-Cap token that signals the server accepts
	// lz4-compressed PUT request bodies (Content-Encoding: lz4).
	CapPutLZ4 = "putlz4"
)
43 changes: 34 additions & 9 deletions gocached/gocached.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ import (
"path/filepath"
"reflect"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/bradfitz/go-tool-cache/consts"
ijwt "github.com/bradfitz/go-tool-cache/gocached/internal/jwt"
"github.com/pierrec/lz4/v4"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -563,6 +565,8 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
srv.logf("ServeHTTP: %s %s", r.Method, r.RequestURI)
}

w.Header().Set(consts.CapHeader, consts.CapPutLZ4)

var sessionData *sessionData // remains nil for unauthenticated requests.
reqStats := &stats{}
defer func() {
Expand Down Expand Up @@ -962,14 +966,37 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats)
return
}

// Determine the body reader and content size. If the client sent
// lz4-compressed data, wrap the body with a decompressor and use the
// uncompressed size for all downstream logic (hashing, inline vs disk
// decisions, etc.). The server will re-compress when writing to disk.
// The goal is reducing bytes on the wire, not saving CPU, so the
// decompress-then-recompress round trip is fine.
var bodyReader io.Reader = r.Body
contentSize := r.ContentLength
if r.Header.Get("Content-Encoding") == "lz4" {
sizeStr := r.Header.Get("X-Uncompressed-Length")
if sizeStr == "" {
http.Error(w, "lz4-compressed PUT missing X-Uncompressed-Length", http.StatusBadRequest)
return
}
var err error
contentSize, err = strconv.ParseInt(sizeStr, 10, 64)
if err != nil {
http.Error(w, "invalid X-Uncompressed-Length", http.StatusBadRequest)
return
}
bodyReader = lz4.NewReader(r.Body)
}

hasher := sha256.New()
hashingBody := io.TeeReader(r.Body, hasher)
hashingBody := io.TeeReader(bodyReader, hasher)

storedSize := r.ContentLength
uncompressedSize := r.ContentLength
storedSize := contentSize
uncompressedSize := contentSize

var smallData []byte
if r.ContentLength <= smallObjectSize {
if contentSize <= smallObjectSize {
// Store small objects inline in the database.
var err error
smallData, err = io.ReadAll(hashingBody)
Expand All @@ -978,15 +1005,13 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats)
http.Error(w, "Read content error", http.StatusInternalServerError)
return
}
if int64(len(smallData)) != r.ContentLength {
// This check is redundant with net/http's validation, but
// for extra clarity.
if int64(len(smallData)) != contentSize {
http.Error(w, "bad content length", http.StatusInternalServerError)
return
}
} else {
// For larger objects, we store them on disk (lz4 compressed).
diskSize, err := s.writeDiskBlob(r.ContentLength, hashingBody)
diskSize, err := s.writeDiskBlob(contentSize, hashingBody)
if err != nil {
s.logf("Write disk blob error: %v", err)
http.Error(w, "Write disk blob error", http.StatusInternalServerError)
Expand Down Expand Up @@ -1049,7 +1074,7 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats)
}

stats.Puts++
stats.PutsBytes += r.ContentLength
stats.PutsBytes += contentSize
if smallData != nil {
stats.PutsInline++
}
Expand Down
78 changes: 78 additions & 0 deletions gocached/gocached_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,84 @@ func TestLZ4Storage(t *testing.T) {
}
}

// roundTripFunc adapts an ordinary function to the http.RoundTripper
// interface, so a closure can be used as an http.Client Transport.
type roundTripFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripFunc(nil)

// RoundTrip implements http.RoundTripper by invoking fn with r and returning
// its results unchanged.
func (fn roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return fn(r)
}

// TestCompressPuts checks the client/server negotiation of lz4-compressed
// PUT bodies: the first Put is sent uncompressed, later non-empty Puts carry
// Content-Encoding: lz4, and the stored values read back unchanged through a
// client that has not seen any capabilities.
func TestCompressPuts(t *testing.T) {
	tester := newServerTester(t)

	// lz4Puts counts PUT requests that left the client with
	// Content-Encoding: lz4. The counting transport forwards every request
	// to http.DefaultTransport untouched.
	var lz4Puts atomic.Int32
	countingTransport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
		isLZ4Put := r.Method == "PUT" && r.Header.Get("Content-Encoding") == "lz4"
		if isLZ4Put {
			lz4Puts.Add(1)
		}
		return http.DefaultTransport.RoundTrip(r)
	})

	capClient := tester.mkClient()
	capClient.HTTPClient = &http.Client{Transport: countingTransport}

	// Nothing has told the client about putlz4 yet, so this first Put must
	// be uncompressed. Its response carries Gocached-Cap: putlz4, which is
	// what enables compression for the Puts that follow.
	tester.wantPut(capClient, "0010", "9010", "first-put")
	before := lz4Puts.Load()
	if before != 0 {
		t.Fatalf("expected 0 lz4 puts before cap discovery, got %d", before)
	}

	// From here on the client knows the server supports putlz4.
	type putCase struct {
		name string
		val  string
	}
	cases := []putCase{
		{name: "empty", val: ""},
		{name: "small_inline", val: "hello"},
		{name: "inline_max", val: strings.Repeat("a", smallObjectSize)},
		{name: "disk_no_lz4", val: strings.Repeat("b", smallObjectSize+1)},
		{name: "disk_lz4_threshold", val: strings.Repeat("c", lz4CompressThreshold)},
		{name: "disk_large", val: strings.Repeat("d", 4096)},
	}

	for idx, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			action := fmt.Sprintf("%04x", idx+0x20)
			output := fmt.Sprintf("%04x", idx+0xa0)

			tester.wantPut(capClient, action, output, tc.val)

			// Read back through a brand-new client to confirm the server
			// decompressed the wire body and stored the original bytes.
			reader := tester.mkClient()
			tester.wantGet(reader, action, output, tc.val)
		})
	}

	// Every case except "empty" has size > 0 and so should have gone out
	// lz4-compressed; the zero-length Put is never compressed.
	wantLZ4 := int32(len(cases) - 1)
	after := lz4Puts.Load()
	if after != wantLZ4 {
		t.Errorf("expected %d lz4 puts after cap discovery, got %d", wantLZ4, after)
	}

	// Dedup check: identical content uploaded uncompressed (by a client
	// with no caps) and compressed (by the cap-aware client) under the same
	// output ID must both read back as the original value.
	plainClient := tester.mkClient()
	payload := strings.Repeat("e", 4096)

	tester.wantPut(plainClient, "aa01", "bb01", payload)
	tester.wantPut(capClient, "aa02", "bb01", payload)

	verifier := tester.mkClient()
	tester.wantGet(verifier, "aa01", "bb01", payload)
	tester.wantGet(verifier, "aa02", "bb01", payload)
}

func TestClientConnReuse(t *testing.T) {
st := newServerTester(t)

Expand Down