Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions http/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,19 @@ func TraceInfoFromHeader(header HeaderGetter, allowedBaggage ...string) (rv Trac
traceParent := header.Get(traceParentHeader)
traceState := header.Get(traceStateHeader)
baggage := header.Get(baggageHeader)
bm := map[string]string{}
if baggage != "" {
for _, kv := range strings.Split(baggage, ",") {
if key, value, ok := strings.Cut(kv, "="); ok {
for _, b := range allowedBaggage {
if key == b {
bm[key] = value
break
}
}
}
}
}

if traceParent != "" {
parts := strings.Split(traceParent, "-")
Expand All @@ -91,19 +104,6 @@ func TraceInfoFromHeader(header HeaderGetter, allowedBaggage ...string) (rv Trac
return rv
}

bm := map[string]string{}
if baggage != "" {
for _, kv := range strings.Split(baggage, ",") {
if key, value, ok := strings.Cut(kv, "="); ok {
for _, b := range allowedBaggage {
if key == b {
bm[key] = value
break
}
}
}
}
}
return TraceInfo{
TraceId: &traceID,
ParentId: &parentID,
Expand All @@ -116,6 +116,7 @@ func TraceInfoFromHeader(header HeaderGetter, allowedBaggage ...string) (rv Trac
if strings.Contains(traceState, "sampled=true") {
return TraceInfo{
Sampled: true,
Baggage: bm,
}
}
return rv
Expand Down
263 changes: 263 additions & 0 deletions http/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
// Copyright (C) 2022 Storj Labs, Inc.
// See LICENSE for copying information.

package http

import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/spacemonkeygo/monkit/v3"
)

type TraceResponse struct {
TraceID string `json:"trace_id"`
SpanID string `json:"span_id"`
Annotations map[string]string `json:"annotations"`
}

func TestTraceHandlerIntegration(t *testing.T) {
scope := monkit.Package()

// Create a simple handler that returns trace information
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := monkit.SpanFromCtx(r.Context())
if span == nil {
http.Error(w, "No span found in context", http.StatusInternalServerError)
return
}

annotations := make(map[string]string)
for _, annotation := range span.Annotations() {
annotations[annotation.Name] = annotation.Value
}

response := TraceResponse{
TraceID: fmt.Sprintf("%016x", span.Trace().Id()),
SpanID: fmt.Sprintf("%016x", span.Id()),
Annotations: annotations,
}

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
})

// Wrap with TraceHandler
traceHandler := TraceHandler(handler, scope, "foo")

// Create test server
server := httptest.NewServer(traceHandler)
defer server.Close()

t.Run("propagate existing trace", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/test", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}

req.Header.Set("traceparent", "00-0000000000000001-00000002-01")
req.Header.Set("baggage", "foo=bar,forbidden=ignore")

traceResp := doRequest(t, err, req)

if traceResp.TraceID == "" {
t.Error("Expected trace ID to be present")
}

if traceResp.SpanID == "" {
t.Error("Expected span ID to be present")
}

if traceResp.Annotations["http.uri"] != "/test" {
t.Errorf("Expected http.uri annotation to be '/test', got '%s'", traceResp.Annotations["http.uri"])
}

if traceResp.Annotations["foo"] != "bar" {
t.Errorf("Annotation is missing")
}

if traceResp.Annotations["forbidden"] != "" {
t.Errorf("Annotation should be missing")
}
})

t.Run("orphan trace", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/test", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}

req.Header.Set("baggage", "foo=bar,forbidden=ignore")
req.Header.Set("tracestate", "sampled=true")

traceResp := doRequest(t, err, req)

if traceResp.TraceID == "" {
t.Error("Expected trace ID to be present")
}

if traceResp.SpanID == "" {
t.Error("Expected span ID to be present")
}

if traceResp.Annotations["http.uri"] != "/test" {
t.Errorf("Expected http.uri annotation to be '/test', got '%s'", traceResp.Annotations["http.uri"])
}

if traceResp.Annotations["foo"] != "bar" {
t.Errorf("Annotation is missing")
}

if traceResp.Annotations["forbidden"] != "" {
t.Errorf("Annotation should be missing")
}
})
}

func doRequest(t *testing.T, err error, req *http.Request) TraceResponse {
// Make request
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}

// Parse response
var traceResp TraceResponse
if err := json.NewDecoder(resp.Body).Decode(&traceResp); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
return traceResp
}

func TestTraceHandlerWithCustomBaggage(t *testing.T) {
scope := monkit.Package()

// Create handler that checks for baggage annotations
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := monkit.SpanFromCtx(r.Context())
if span == nil {
http.Error(w, "No span found in context", http.StatusInternalServerError)
return
}

// The baggage should be added as annotations to the span
// We'll just return success - the baggage testing is indirect through span annotations
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})

// Create TraceHandler with allowed baggage
traceHandler := TraceHandler(handler, scope, "allowed-key", "another-allowed")

server := httptest.NewServer(traceHandler)
defer server.Close()

// Test with baggage header
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}

req.Header.Set("traceparent", "00-0000000000000001-00000002-01")
req.Header.Set("baggage", "allowed-key=allowed-value,not-allowed=ignored,another-allowed=another-value")

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}

// The test passes if the handler completes successfully
// Baggage is verified indirectly through the span annotations in the TraceHandler
}

func TestTraceHandlerContextPropagation(t *testing.T) {
scope := monkit.Package()

// Handler that verifies context propagation
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
span := monkit.SpanFromCtx(ctx)

if span == nil {
http.Error(w, "No span in context", http.StatusInternalServerError)
return
}

// Try to start a child span to verify context works
childCtx := ctx
defer scope.Func().RemoteTrace(&childCtx, span.Id(), span.Trace())(nil)

childSpan := monkit.SpanFromCtx(childCtx)
if childSpan == nil {
http.Error(w, "Could not create child span", http.StatusInternalServerError)
return
}

response := map[string]string{
"parent_trace_id": fmt.Sprintf("%016x", span.Trace().Id()),
"parent_span_id": fmt.Sprintf("%016x", span.Id()),
"child_trace_id": fmt.Sprintf("%016x", childSpan.Trace().Id()),
"child_span_id": fmt.Sprintf("%016x", childSpan.Id()),
}

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
})

traceHandler := TraceHandler(handler, scope)
server := httptest.NewServer(traceHandler)
defer server.Close()

// Make request with trace parent
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("traceparent", "00-0000000000000001-00000002-01")

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}

var result map[string]string
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}

// Verify trace IDs match
if result["parent_trace_id"] != result["child_trace_id"] {
t.Errorf("Parent and child should share same trace ID: parent=%s, child=%s",
result["parent_trace_id"], result["child_trace_id"])
}

// Verify span IDs are different
if result["parent_span_id"] == result["child_span_id"] {
t.Error("Parent and child should have different span IDs")
}

// Verify expected trace ID
if result["parent_trace_id"] != "0000000000000001" {
t.Errorf("Expected trace ID 0000000000000001, got %s", result["parent_trace_id"])
}
}
Loading