Skip to content

Commit 41d9e05

Browse files
authored
feat: add Hijack support for WebSocket upgrades in metrics handler (#106)
1 parent 2efa456 commit 41d9e05

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

pkg/metrics/handler_method.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package metrics
22

33
import (
4+
"bufio"
5+
"net"
46
"net/http"
57
"time"
68

@@ -152,3 +154,13 @@ func (rw *responseWriter) WriteHeader(statusCode int) {
152154
rw.statusCode = statusCode
153155
rw.ResponseWriter.WriteHeader(statusCode)
154156
}
157+
158+
// Hijack delegates to the underlying ResponseWriter when it supports http.Hijacker.
159+
// This is required for WebSocket upgrades to work through the metrics wrapper.
160+
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
161+
h, ok := rw.ResponseWriter.(http.Hijacker)
162+
if !ok {
163+
return nil, nil, http.ErrNotSupported
164+
}
165+
return h.Hijack()
166+
}

pkg/metrics/metrics_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package metrics
22

33
import (
4+
"bufio"
5+
"errors"
46
"math/rand"
7+
"net"
58
"net/http"
69
"testing"
710
"time"
@@ -814,6 +817,32 @@ func NewMockResponseWriter() *MockResponseWriter {
814817
}
815818
}
816819

820+
// MockHijackResponseWriter is a mock ResponseWriter that supports hijacking.
821+
type MockHijackResponseWriter struct {
822+
*MockResponseWriter
823+
hijackCalled bool
824+
conn net.Conn
825+
peer net.Conn
826+
buf *bufio.ReadWriter
827+
}
828+
829+
// NewMockHijackResponseWriter creates a new MockHijackResponseWriter.
830+
func NewMockHijackResponseWriter() *MockHijackResponseWriter {
831+
conn, peer := net.Pipe()
832+
return &MockHijackResponseWriter{
833+
MockResponseWriter: NewMockResponseWriter(),
834+
conn: conn,
835+
peer: peer,
836+
buf: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
837+
}
838+
}
839+
840+
// Hijack records the call and returns the mocked connection and buffer.
841+
func (w *MockHijackResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
842+
w.hijackCalled = true
843+
return w.conn, w.buf, nil
844+
}
845+
817846
// TestMetricsMiddlewareImpl_Handler tests the Handler method of MetricsMiddlewareImpl
818847
func TestMetricsMiddlewareImpl_Handler(t *testing.T) {
819848
// Create a mock registry
@@ -1181,3 +1210,47 @@ func TestResponseWriter(t *testing.T) {
11811210
t.Errorf("Expected underlying data %q, got %q", "Test data", string(mockRw.writtenData))
11821211
}
11831212
}
1213+
1214+
// TestResponseWriterHijack ensures hijack calls are forwarded.
1215+
func TestResponseWriterHijack(t *testing.T) {
1216+
mockRw := NewMockHijackResponseWriter()
1217+
t.Cleanup(func() {
1218+
_ = mockRw.conn.Close()
1219+
_ = mockRw.peer.Close()
1220+
})
1221+
1222+
rw := &responseWriter{
1223+
ResponseWriter: mockRw,
1224+
statusCode: http.StatusOK,
1225+
}
1226+
1227+
conn, buf, err := rw.Hijack()
1228+
if err != nil {
1229+
t.Fatalf("Expected hijack to succeed, got error: %v", err)
1230+
}
1231+
if !mockRw.hijackCalled {
1232+
t.Error("Expected Hijack to be called on the underlying ResponseWriter")
1233+
}
1234+
if conn != mockRw.conn {
1235+
t.Errorf("Expected hijack connection to match, got %v", conn)
1236+
}
1237+
if buf != mockRw.buf {
1238+
t.Error("Expected hijack read-writer to match")
1239+
}
1240+
}
1241+
1242+
// TestResponseWriterHijackUnsupported ensures unsupported hijacks return ErrNotSupported.
1243+
func TestResponseWriterHijackUnsupported(t *testing.T) {
1244+
rw := &responseWriter{
1245+
ResponseWriter: NewMockResponseWriter(),
1246+
statusCode: http.StatusOK,
1247+
}
1248+
1249+
conn, buf, err := rw.Hijack()
1250+
if !errors.Is(err, http.ErrNotSupported) {
1251+
t.Fatalf("Expected ErrNotSupported, got %v", err)
1252+
}
1253+
if conn != nil || buf != nil {
1254+
t.Error("Expected nil connection and buffer on unsupported hijack")
1255+
}
1256+
}

0 commit comments

Comments
 (0)