Skip to content

Commit b68719f

Browse files
feat: add AutoHead feature to automatically register HEAD routes for GET requests
1 parent 0143b9d commit b68719f

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

echo.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ import (
5454
"os"
5555
"os/signal"
5656
"path/filepath"
57+
"strconv"
5758
"strings"
5859
"sync"
5960
"sync/atomic"
@@ -100,6 +101,7 @@ type Echo struct {
100101

101102
// formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm)
102103
formParseMaxMemory int64
104+
AutoHead bool
103105
}
104106

105107
// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces.
@@ -288,6 +290,11 @@ type Config struct {
288290
// FormParseMaxMemory is default value for memory limit that is used
289291
// when parsing multipart forms (See (*http.Request).ParseMultipartForm)
290292
FormParseMaxMemory int64
293+
294+
// AutoHead enables automatic registration of HEAD routes for GET routes.
295+
// When enabled, a HEAD request to a GET-only path will be handled automatically
296+
// using the same handler as GET, with the response body suppressed.
297+
AutoHead bool
291298
}
292299

293300
// NewWithConfig creates an instance of Echo with given configuration.
@@ -326,6 +333,9 @@ func NewWithConfig(config Config) *Echo {
326333
if config.FormParseMaxMemory > 0 {
327334
e.formParseMaxMemory = config.FormParseMaxMemory
328335
}
336+
if config.AutoHead {
337+
e.AutoHead = config.AutoHead
338+
}
329339
return e
330340
}
331341

@@ -421,6 +431,67 @@ func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler {
421431
}
422432
}
423433

434+
// headResponseWriter wraps an http.ResponseWriter and suppresses the response body
435+
// while preserving headers and status code. Used for automatic HEAD route handling.
436+
// It counts the bytes that would have been written so we can set Content-Length accurately.
437+
type headResponseWriter struct {
438+
http.ResponseWriter
439+
bytesWritten int64
440+
statusCode int
441+
wroteHeader bool
442+
}
443+
444+
// Write intercepts writes to the response body and counts bytes without actually writing them.
445+
func (hw *headResponseWriter) Write(b []byte) (int, error) {
446+
if !hw.wroteHeader {
447+
hw.statusCode = http.StatusOK
448+
hw.wroteHeader = true
449+
}
450+
hw.bytesWritten += int64(len(b))
451+
// Return success without actually writing the body for HEAD requests
452+
return len(b), nil
453+
}
454+
455+
// WriteHeader intercepts the status code but still writes it to the underlying ResponseWriter.
456+
func (hw *headResponseWriter) WriteHeader(statusCode int) {
457+
if !hw.wroteHeader {
458+
hw.statusCode = statusCode
459+
hw.wroteHeader = true
460+
hw.ResponseWriter.WriteHeader(statusCode)
461+
}
462+
}
463+
464+
// Unwrap returns the underlying http.ResponseWriter for compatibility with echo.Response unwrapping.
465+
func (hw *headResponseWriter) Unwrap() http.ResponseWriter {
466+
return hw.ResponseWriter
467+
}
468+
469+
func wrapHeadHandler(handler HandlerFunc) HandlerFunc {
470+
return func(c *Context) error {
471+
if c.Request().Method != http.MethodHead {
472+
return handler(c)
473+
}
474+
originalWriter := c.Response()
475+
headWriter := &headResponseWriter{ResponseWriter: originalWriter}
476+
477+
c.SetResponse(headWriter)
478+
defer func() {
479+
c.SetResponse(originalWriter)
480+
}()
481+
err := handler(c)
482+
483+
if headWriter.bytesWritten > 0 {
484+
originalWriter.Header().Set("Content-Length", strconv.FormatInt(headWriter.bytesWritten, 10))
485+
}
486+
487+
if !headWriter.wroteHeader && headWriter.statusCode > 0 {
488+
originalWriter.WriteHeader(headWriter.statusCode)
489+
}
490+
491+
return err
492+
}
493+
}
494+
424495
// Pre adds middleware to the chain which is run before router tries to find matching route.
425496
// Meaning middleware is executed even for 404 (not found) cases.
426497
func (e *Echo) Pre(middleware ...MiddlewareFunc) {
@@ -634,6 +705,20 @@ func (e *Echo) add(route Route) (RouteInfo, error) {
634705
if paramsCount > e.contextPathParamAllocSize.Load() {
635706
e.contextPathParamAllocSize.Store(paramsCount)
636707
}
708+
709+
// Auto-register HEAD route for GET if AutoHead is enabled
710+
if e.AutoHead && route.Method == http.MethodGet {
711+
headRoute := Route{
712+
Method: http.MethodHead,
713+
Path: route.Path,
714+
Handler: wrapHeadHandler(route.Handler),
715+
Middlewares: route.Middlewares,
716+
Name: route.Name,
717+
}
718+
// Attempt to add HEAD route, but ignore errors if an explicit HEAD route already exists
719+
_, _ = e.router.Add(headRoute)
720+
}
721+
637722
return ri, nil
638723
}
639724

@@ -642,6 +727,7 @@ func (e *Echo) add(route Route) (RouteInfo, error) {
642727
func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
643728
ri, err := e.add(
644729
Route{
730+
645731
Method: method,
646732
Path: path,
647733
Handler: handler,

0 commit comments

Comments
 (0)