Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pkg/metrics/handler_method.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package metrics

import (
"bufio"
"net"
"net/http"
"time"

Expand Down Expand Up @@ -152,3 +154,13 @@ func (rw *responseWriter) WriteHeader(statusCode int) {
rw.statusCode = statusCode
rw.ResponseWriter.WriteHeader(statusCode)
}

// Hijack delegates to the underlying ResponseWriter when it supports http.Hijacker.
// This is required for WebSocket upgrades to work through the metrics wrapper.
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, http.ErrNotSupported
}
return h.Hijack()
}
73 changes: 73 additions & 0 deletions pkg/metrics/metrics_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package metrics

import (
"bufio"
"errors"
"math/rand"
"net"
"net/http"
"testing"
"time"
Expand Down Expand Up @@ -814,6 +817,32 @@ func NewMockResponseWriter() *MockResponseWriter {
}
}

// MockHijackResponseWriter is a mock ResponseWriter that supports hijacking.
type MockHijackResponseWriter struct {
*MockResponseWriter
hijackCalled bool
conn net.Conn
peer net.Conn
buf *bufio.ReadWriter
}

// NewMockHijackResponseWriter creates a new MockHijackResponseWriter.
func NewMockHijackResponseWriter() *MockHijackResponseWriter {
conn, peer := net.Pipe()
return &MockHijackResponseWriter{
MockResponseWriter: NewMockResponseWriter(),
conn: conn,
peer: peer,
buf: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
}
}

// Hijack records the call and returns the mocked connection and buffer.
func (w *MockHijackResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
w.hijackCalled = true
return w.conn, w.buf, nil
}

// TestMetricsMiddlewareImpl_Handler tests the Handler method of MetricsMiddlewareImpl
func TestMetricsMiddlewareImpl_Handler(t *testing.T) {
// Create a mock registry
Expand Down Expand Up @@ -1181,3 +1210,47 @@ func TestResponseWriter(t *testing.T) {
t.Errorf("Expected underlying data %q, got %q", "Test data", string(mockRw.writtenData))
}
}

// TestResponseWriterHijack ensures hijack calls are forwarded.
func TestResponseWriterHijack(t *testing.T) {
mockRw := NewMockHijackResponseWriter()
t.Cleanup(func() {
_ = mockRw.conn.Close()
_ = mockRw.peer.Close()
})

rw := &responseWriter{
ResponseWriter: mockRw,
statusCode: http.StatusOK,
}

conn, buf, err := rw.Hijack()
if err != nil {
t.Fatalf("Expected hijack to succeed, got error: %v", err)
}
if !mockRw.hijackCalled {
t.Error("Expected Hijack to be called on the underlying ResponseWriter")
}
if conn != mockRw.conn {
t.Errorf("Expected hijack connection to match, got %v", conn)
}
if buf != mockRw.buf {
t.Error("Expected hijack read-writer to match")
}
}

// TestResponseWriterHijackUnsupported ensures unsupported hijacks return ErrNotSupported.
func TestResponseWriterHijackUnsupported(t *testing.T) {
rw := &responseWriter{
ResponseWriter: NewMockResponseWriter(),
statusCode: http.StatusOK,
}

conn, buf, err := rw.Hijack()
if !errors.Is(err, http.ErrNotSupported) {
t.Fatalf("Expected ErrNotSupported, got %v", err)
}
if conn != nil || buf != nil {
t.Error("Expected nil connection and buffer on unsupported hijack")
}
}