Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions internal/strategy/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"maps"
"net/http"
"os"
"sort"
"strings"
"time"

"github.com/alecthomas/errors"
Expand All @@ -14,6 +16,38 @@ import (
"github.com/block/cachew/internal/logging"
)

// CacheKeyParts holds the components used to build a cache key. Path is the
// primary identifier (typically the upstream URL) and Vary captures
// request-derived dimensions like Accept-Encoding.
type CacheKeyParts struct {
Path string
Vary map[string]string
}

func NewCacheKeyParts(path string) CacheKeyParts {
return CacheKeyParts{Path: path, Vary: make(map[string]string)}
}

func (p CacheKeyParts) Key() cache.Key {
if len(p.Vary) == 0 {
return cache.NewKey(p.Path)
}
keys := make([]string, 0, len(p.Vary))
for k := range p.Vary {
keys = append(keys, k)
}
sort.Strings(keys)
var b strings.Builder
b.WriteString(p.Path)
for _, k := range keys {
b.WriteByte('\n')
b.WriteString(k)
b.WriteByte('=')
b.WriteString(p.Vary[k])
}
return cache.NewKey(b.String())
}

// Handler provides a fluent API for creating cache-backed HTTP handlers.
//
// Example usage:
Expand All @@ -29,7 +63,7 @@ import (
type Handler struct {
client *http.Client
cache cache.Cache
cacheKeyFunc func(*http.Request) string
cacheKeyFunc func(*http.Request) CacheKeyParts
transformFunc func(*http.Request) (*http.Request, error)
errorHandler func(error, http.ResponseWriter, *http.Request)
ttlFunc func(*http.Request) time.Duration
Expand All @@ -44,8 +78,8 @@ func New(client *http.Client, c cache.Cache) *Handler {
return &Handler{
client: client,
cache: c,
cacheKeyFunc: func(r *http.Request) string {
return r.URL.String()
cacheKeyFunc: func(r *http.Request) CacheKeyParts {
return NewCacheKeyParts(r.URL.String())
},
transformFunc: func(r *http.Request) (*http.Request, error) {
return r, nil
Expand All @@ -60,7 +94,9 @@ func New(client *http.Client, c cache.Cache) *Handler {
// CacheKey sets the function used to determine the cache key for a request.
// The function receives the original incoming request.
func (h *Handler) CacheKey(f func(*http.Request) string) *Handler {
h.cacheKeyFunc = f
h.cacheKeyFunc = func(r *http.Request) CacheKeyParts {
return NewCacheKeyParts(f(r))
}
return h
}

Expand Down Expand Up @@ -100,10 +136,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

logger := logging.FromContext(ctx)

cacheKeyStr := h.cacheKeyFunc(r)
key := cache.NewKey(cacheKeyStr)
parts := h.cacheKeyFunc(r)
// Upstreams may return different content based on Accept-Encoding (e.g.
// gzip-compressed vs uncompressed). Without this, the first variant cached
// is served to all clients, breaking those that expect the other encoding.
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
parts.Vary["Accept-Encoding"] = ae
}
Comment thread
joshfriend marked this conversation as resolved.
key := parts.Key()

logger.DebugContext(ctx, "Processing request", "cache_key", cacheKeyStr)
logger.DebugContext(ctx, "Processing request", "cache_key", parts.Path)

served, err := h.serveCached(w, r, key)
if err != nil {
Expand Down
37 changes: 37 additions & 0 deletions internal/strategy/handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,43 @@ func TestHeaderForwarding(t *testing.T) {
assert.Equal(t, "", receivedHeaders.Get("Keep-Alive"))
})

t.Run("AcceptEncodingVariesCacheKey", func(t *testing.T) {
callCount := 0
varyUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Vary", "Accept-Encoding")
_, _ = fmt.Fprintf(w, "call %d ae=%s", callCount, r.Header.Get("Accept-Encoding"))
}))
defer varyUpstream.Close()

varyCache := mustNewMemoryCache()
h := handler.New(http.DefaultClient, varyCache).
Transform(func(r *http.Request) (*http.Request, error) {
return http.NewRequestWithContext(r.Context(), http.MethodGet, varyUpstream.URL+"/file.zip", nil)
})

// Request without Accept-Encoding
r1 := httptest.NewRequest(http.MethodGet, "http://example.com/file.zip", nil)
r1 = r1.WithContext(ctx)
w1 := httptest.NewRecorder()
h.ServeHTTP(w1, r1)
assert.Equal(t, http.StatusOK, w1.Code)
body1 := w1.Body.String()

// Request with Accept-Encoding: gzip — should be a separate cache entry
r2 := httptest.NewRequest(http.MethodGet, "http://example.com/file.zip", nil)
r2 = r2.WithContext(ctx)
r2.Header.Set("Accept-Encoding", "gzip")
w2 := httptest.NewRecorder()
h.ServeHTTP(w2, r2)
assert.Equal(t, http.StatusOK, w2.Code)
body2 := w2.Body.String()

// Both should have hit upstream (different cache keys)
assert.Equal(t, 2, callCount)
assert.NotEqual(t, body1, body2)
})

t.Run("TransformHeadersTakePrecedence", func(t *testing.T) {
h := handler.New(http.DefaultClient, c).
CacheKey(func(_ *http.Request) string { return "fwd-test-3" }).
Expand Down