From 17268ccc24da743702e6e84c8cc7f0a2fc87719a Mon Sep 17 00:00:00 2001 From: Josh Friend Date: Fri, 15 May 2026 11:07:13 -0400 Subject: [PATCH] fix(handler): vary cache key by Accept-Encoding to prevent cache poisoning Introduce CacheKeyParts to cleanly separate the cache key path from request-derived dimensions like Accept-Encoding. Upstreams like dl.google.com return different content based on Accept-Encoding, and without this the first variant cached gets served to all clients. --- internal/strategy/handler/handler.go | 56 ++++++++++++++++++++--- internal/strategy/handler/handler_test.go | 37 +++++++++++++++ 2 files changed, 86 insertions(+), 7 deletions(-) diff --git a/internal/strategy/handler/handler.go b/internal/strategy/handler/handler.go index 029e6fe..7079876 100644 --- a/internal/strategy/handler/handler.go +++ b/internal/strategy/handler/handler.go @@ -5,6 +5,8 @@ import ( "maps" "net/http" "os" + "sort" + "strings" "time" "github.com/alecthomas/errors" @@ -14,6 +16,38 @@ import ( "github.com/block/cachew/internal/logging" ) +// CacheKeyParts holds the components used to build a cache key. Path is the +// primary identifier (typically the upstream URL) and Vary captures +// request-derived dimensions like Accept-Encoding. +type CacheKeyParts struct { + Path string + Vary map[string]string +} + +func NewCacheKeyParts(path string) CacheKeyParts { + return CacheKeyParts{Path: path, Vary: make(map[string]string)} +} + +func (p CacheKeyParts) Key() cache.Key { + if len(p.Vary) == 0 { + return cache.NewKey(p.Path) + } + keys := make([]string, 0, len(p.Vary)) + for k := range p.Vary { + keys = append(keys, k) + } + sort.Strings(keys) + var b strings.Builder + b.WriteString(p.Path) + for _, k := range keys { + b.WriteByte('\n') + b.WriteString(k) + b.WriteByte('=') + b.WriteString(p.Vary[k]) + } + return cache.NewKey(b.String()) +} + // Handler provides a fluent API for creating cache-backed HTTP handlers. // // Example usage: @@ -29,7 +63,7 @@ import ( type Handler struct { client *http.Client cache cache.Cache - cacheKeyFunc func(*http.Request) string + cacheKeyFunc func(*http.Request) CacheKeyParts transformFunc func(*http.Request) (*http.Request, error) errorHandler func(error, http.ResponseWriter, *http.Request) ttlFunc func(*http.Request) time.Duration @@ -44,8 +78,8 @@ func New(client *http.Client, c cache.Cache) *Handler { return &Handler{ client: client, cache: c, - cacheKeyFunc: func(r *http.Request) string { - return r.URL.String() + cacheKeyFunc: func(r *http.Request) CacheKeyParts { + return NewCacheKeyParts(r.URL.String()) }, transformFunc: func(r *http.Request) (*http.Request, error) { return r, nil @@ -60,7 +94,9 @@ func New(client *http.Client, c cache.Cache) *Handler { // CacheKey sets the function used to determine the cache key for a request. // The function receives the original incoming request. func (h *Handler) CacheKey(f func(*http.Request) string) *Handler { - h.cacheKeyFunc = f + h.cacheKeyFunc = func(r *http.Request) CacheKeyParts { + return NewCacheKeyParts(f(r)) + } return h } @@ -100,10 +136,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { logger := logging.FromContext(ctx) - cacheKeyStr := h.cacheKeyFunc(r) - key := cache.NewKey(cacheKeyStr) + parts := h.cacheKeyFunc(r) + // Upstreams may return different content based on Accept-Encoding (e.g. + // gzip-compressed vs uncompressed). Without this, the first variant cached + // is served to all clients, breaking those that expect the other encoding. + if ae := r.Header.Get("Accept-Encoding"); ae != "" { + parts.Vary["Accept-Encoding"] = ae + } + key := parts.Key() - logger.DebugContext(ctx, "Processing request", "cache_key", cacheKeyStr) + logger.DebugContext(ctx, "Processing request", "cache_key", parts.Path) served, err := h.serveCached(w, r, key) if err != nil { diff --git a/internal/strategy/handler/handler_test.go b/internal/strategy/handler/handler_test.go index 4c43026..83d6faa 100644 --- a/internal/strategy/handler/handler_test.go +++ b/internal/strategy/handler/handler_test.go @@ -331,6 +331,43 @@ func TestHeaderForwarding(t *testing.T) { assert.Equal(t, "", receivedHeaders.Get("Keep-Alive")) }) + t.Run("AcceptEncodingVariesCacheKey", func(t *testing.T) { + callCount := 0 + varyUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Vary", "Accept-Encoding") + _, _ = fmt.Fprintf(w, "call %d ae=%s", callCount, r.Header.Get("Accept-Encoding")) + })) + defer varyUpstream.Close() + + varyCache := mustNewMemoryCache() + h := handler.New(http.DefaultClient, varyCache). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, varyUpstream.URL+"/file.zip", nil) + }) + + // Request without Accept-Encoding + r1 := httptest.NewRequest(http.MethodGet, "http://example.com/file.zip", nil) + r1 = r1.WithContext(ctx) + w1 := httptest.NewRecorder() + h.ServeHTTP(w1, r1) + assert.Equal(t, http.StatusOK, w1.Code) + body1 := w1.Body.String() + + // Request with Accept-Encoding: gzip — should be a separate cache entry + r2 := httptest.NewRequest(http.MethodGet, "http://example.com/file.zip", nil) + r2 = r2.WithContext(ctx) + r2.Header.Set("Accept-Encoding", "gzip") + w2 := httptest.NewRecorder() + h.ServeHTTP(w2, r2) + assert.Equal(t, http.StatusOK, w2.Code) + body2 := w2.Body.String() + + // Both should have hit upstream (different cache keys) + assert.Equal(t, 2, callCount) + assert.NotEqual(t, body1, body2) + }) + t.Run("TransformHeadersTakePrecedence", func(t *testing.T) { h := handler.New(http.DefaultClient, c). CacheKey(func(_ *http.Request) string { return "fwd-test-3" }).