diff --git a/http/info.go b/http/info.go index 58c4e95..cb06739 100644 --- a/http/info.go +++ b/http/info.go @@ -68,6 +68,19 @@ func TraceInfoFromHeader(header HeaderGetter, allowedBaggage ...string) (rv Trac traceParent := header.Get(traceParentHeader) traceState := header.Get(traceStateHeader) baggage := header.Get(baggageHeader) + bm := map[string]string{} + if baggage != "" { + for _, kv := range strings.Split(baggage, ",") { + if key, value, ok := strings.Cut(kv, "="); ok { + for _, b := range allowedBaggage { + if key == b { + bm[key] = value + break + } + } + } + } + } if traceParent != "" { parts := strings.Split(traceParent, "-") @@ -91,19 +104,6 @@ func TraceInfoFromHeader(header HeaderGetter, allowedBaggage ...string) (rv Trac return rv } - bm := map[string]string{} - if baggage != "" { - for _, kv := range strings.Split(baggage, ",") { - if key, value, ok := strings.Cut(kv, "="); ok { - for _, b := range allowedBaggage { - if key == b { - bm[key] = value - break - } - } - } - } - } return TraceInfo{ TraceId: &traceID, ParentId: &parentID, @@ -116,6 +116,7 @@ func TraceInfoFromHeader(header HeaderGetter, allowedBaggage ...string) (rv Trac if strings.Contains(traceState, "sampled=true") { return TraceInfo{ Sampled: true, + Baggage: bm, } } return rv diff --git a/http/server_test.go b/http/server_test.go new file mode 100644 index 0000000..f34cb69 --- /dev/null +++ b/http/server_test.go @@ -0,0 +1,263 @@ +// Copyright (C) 2022 Storj Labs, Inc. +// See LICENSE for copying information. + +package http + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/spacemonkeygo/monkit/v3" +) + +type TraceResponse struct { + TraceID string `json:"trace_id"` + SpanID string `json:"span_id"` + Annotations map[string]string `json:"annotations"` +} + +func TestTraceHandlerIntegration(t *testing.T) { + scope := monkit.Package() + + // Create a simple handler that returns trace information + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + span := monkit.SpanFromCtx(r.Context()) + if span == nil { + http.Error(w, "No span found in context", http.StatusInternalServerError) + return + } + + annotations := make(map[string]string) + for _, annotation := range span.Annotations() { + annotations[annotation.Name] = annotation.Value + } + + response := TraceResponse{ + TraceID: fmt.Sprintf("%016x", span.Trace().Id()), + SpanID: fmt.Sprintf("%016x", span.Id()), + Annotations: annotations, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + }) + + // Wrap with TraceHandler + traceHandler := TraceHandler(handler, scope, "foo") + + // Create test server + server := httptest.NewServer(traceHandler) + defer server.Close() + + t.Run("propagate existing trace", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/test", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + req.Header.Set("traceparent", "00-0000000000000001-00000002-01") + req.Header.Set("baggage", "foo=bar,forbidden=ignore") + + traceResp := doRequest(t, err, req) + + if traceResp.TraceID == "" { + t.Error("Expected trace ID to be present") + } + + if traceResp.SpanID == "" { + t.Error("Expected span ID to be present") + } + + if traceResp.Annotations["http.uri"] != "/test" { + t.Errorf("Expected http.uri annotation to be '/test', got '%s'", traceResp.Annotations["http.uri"]) + } + + if traceResp.Annotations["foo"] != "bar" { + t.Errorf("Annotation is missing") + } + + if traceResp.Annotations["forbidden"] != "" { + t.Errorf("Annotation should be missing") + } + }) + + t.Run("orphan trace", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/test", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + req.Header.Set("baggage", "foo=bar,forbidden=ignore") + req.Header.Set("tracestate", "sampled=true") + + traceResp := doRequest(t, err, req) + + if traceResp.TraceID == "" { + t.Error("Expected trace ID to be present") + } + + if traceResp.SpanID == "" { + t.Error("Expected span ID to be present") + } + + if traceResp.Annotations["http.uri"] != "/test" { + t.Errorf("Expected http.uri annotation to be '/test', got '%s'", traceResp.Annotations["http.uri"]) + } + + if traceResp.Annotations["foo"] != "bar" { + t.Errorf("Annotation is missing") + } + + if traceResp.Annotations["forbidden"] != "" { + t.Errorf("Annotation should be missing") + } + }) +} + +func doRequest(t *testing.T, err error, req *http.Request) TraceResponse { + // Make request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + + // Parse response + var traceResp TraceResponse + if err := json.NewDecoder(resp.Body).Decode(&traceResp); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + return traceResp +} + +func TestTraceHandlerWithCustomBaggage(t *testing.T) { + scope := monkit.Package() + + // Create handler that checks for baggage annotations + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + span := monkit.SpanFromCtx(r.Context()) + if span == nil { + http.Error(w, "No span found in context", http.StatusInternalServerError) + return + } + + // The baggage should be added as annotations to the span + // We'll just return success - the baggage testing is indirect through span annotations + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + // Create TraceHandler with allowed baggage + traceHandler := TraceHandler(handler, scope, "allowed-key", "another-allowed") + + server := httptest.NewServer(traceHandler) + defer server.Close() + + // Test with baggage header + req, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + req.Header.Set("traceparent", "00-0000000000000001-00000002-01") + req.Header.Set("baggage", "allowed-key=allowed-value,not-allowed=ignored,another-allowed=another-value") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + + // The test passes if the handler completes successfully + // Baggage is verified indirectly through the span annotations in the TraceHandler +} + +func TestTraceHandlerContextPropagation(t *testing.T) { + scope := monkit.Package() + + // Handler that verifies context propagation + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + span := monkit.SpanFromCtx(ctx) + + if span == nil { + http.Error(w, "No span in context", http.StatusInternalServerError) + return + } + + // Try to start a child span to verify context works + childCtx := ctx + defer scope.Func().RemoteTrace(&childCtx, span.Id(), span.Trace())(nil) + + childSpan := monkit.SpanFromCtx(childCtx) + if childSpan == nil { + http.Error(w, "Could not create child span", http.StatusInternalServerError) + return + } + + response := map[string]string{ + "parent_trace_id": fmt.Sprintf("%016x", span.Trace().Id()), + "parent_span_id": fmt.Sprintf("%016x", span.Id()), + "child_trace_id": fmt.Sprintf("%016x", childSpan.Trace().Id()), + "child_span_id": fmt.Sprintf("%016x", childSpan.Id()), + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + }) + + traceHandler := TraceHandler(handler, scope) + server := httptest.NewServer(traceHandler) + defer server.Close() + + // Make request with trace parent + req, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("traceparent", "00-0000000000000001-00000002-01") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + + var result map[string]string + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Verify trace IDs match + if result["parent_trace_id"] != result["child_trace_id"] { + t.Errorf("Parent and child should share same trace ID: parent=%s, child=%s", + result["parent_trace_id"], result["child_trace_id"]) + } + + // Verify span IDs are different + if result["parent_span_id"] == result["child_span_id"] { + t.Error("Parent and child should have different span IDs") + } + + // Verify expected trace ID + if result["parent_trace_id"] != "0000000000000001" { + t.Errorf("Expected trace ID 0000000000000001, got %s", result["parent_trace_id"]) + } +}