diff --git a/pkg/metrics/handler_method.go b/pkg/metrics/handler_method.go index 9f737e1..419c5b3 100644 --- a/pkg/metrics/handler_method.go +++ b/pkg/metrics/handler_method.go @@ -1,6 +1,8 @@ package metrics import ( + "bufio" + "net" "net/http" "time" @@ -152,3 +154,13 @@ func (rw *responseWriter) WriteHeader(statusCode int) { rw.statusCode = statusCode rw.ResponseWriter.WriteHeader(statusCode) } + +// Hijack delegates to the underlying ResponseWriter when it supports http.Hijacker. +// This is required for WebSocket upgrades to work through the metrics wrapper. +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := rw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, http.ErrNotSupported + } + return h.Hijack() +} diff --git a/pkg/metrics/metrics_test.go b/pkg/metrics/metrics_test.go index 73e6661..6737fd0 100644 --- a/pkg/metrics/metrics_test.go +++ b/pkg/metrics/metrics_test.go @@ -1,7 +1,10 @@ package metrics import ( + "bufio" + "errors" "math/rand" + "net" "net/http" "testing" "time" @@ -814,6 +817,32 @@ func NewMockResponseWriter() *MockResponseWriter { } } +// MockHijackResponseWriter is a mock ResponseWriter that supports hijacking. +type MockHijackResponseWriter struct { + *MockResponseWriter + hijackCalled bool + conn net.Conn + peer net.Conn + buf *bufio.ReadWriter +} + +// NewMockHijackResponseWriter creates a new MockHijackResponseWriter. +func NewMockHijackResponseWriter() *MockHijackResponseWriter { + conn, peer := net.Pipe() + return &MockHijackResponseWriter{ + MockResponseWriter: NewMockResponseWriter(), + conn: conn, + peer: peer, + buf: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), + } +} + +// Hijack records the call and returns the mocked connection and buffer. +func (w *MockHijackResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + w.hijackCalled = true + return w.conn, w.buf, nil +} + // TestMetricsMiddlewareImpl_Handler tests the Handler method of MetricsMiddlewareImpl func TestMetricsMiddlewareImpl_Handler(t *testing.T) { // Create a mock registry @@ -1181,3 +1210,47 @@ func TestResponseWriter(t *testing.T) { t.Errorf("Expected underlying data %q, got %q", "Test data", string(mockRw.writtenData)) } } + +// TestResponseWriterHijack ensures hijack calls are forwarded. +func TestResponseWriterHijack(t *testing.T) { + mockRw := NewMockHijackResponseWriter() + t.Cleanup(func() { + _ = mockRw.conn.Close() + _ = mockRw.peer.Close() + }) + + rw := &responseWriter{ + ResponseWriter: mockRw, + statusCode: http.StatusOK, + } + + conn, buf, err := rw.Hijack() + if err != nil { + t.Fatalf("Expected hijack to succeed, got error: %v", err) + } + if !mockRw.hijackCalled { + t.Error("Expected Hijack to be called on the underlying ResponseWriter") + } + if conn != mockRw.conn { + t.Errorf("Expected hijack connection to match, got %v", conn) + } + if buf != mockRw.buf { + t.Error("Expected hijack read-writer to match") + } +} + +// TestResponseWriterHijackUnsupported ensures unsupported hijacks return ErrNotSupported. +func TestResponseWriterHijackUnsupported(t *testing.T) { + rw := &responseWriter{ + ResponseWriter: NewMockResponseWriter(), + statusCode: http.StatusOK, + } + + conn, buf, err := rw.Hijack() + if !errors.Is(err, http.ErrNotSupported) { + t.Fatalf("Expected ErrNotSupported, got %v", err) + } + if conn != nil || buf != nil { + t.Error("Expected nil connection and buffer on unsupported hijack") + } +}