diff --git a/internal/beater/monitoringtest/opentelemetry.go b/internal/beater/monitoringtest/opentelemetry.go index 9b643f3d171..271399e4138 100644 --- a/internal/beater/monitoringtest/opentelemetry.go +++ b/internal/beater/monitoringtest/opentelemetry.go @@ -42,6 +42,18 @@ func ExpectContainOtelMetrics( assertOtelMetrics(t, reader, expectedMetrics, false, false) } +func ExpectContainAndNotContainOtelMetrics( + t *testing.T, + reader sdkmetric.Reader, + expectedMetrics map[string]any, + notExpectedMetricKeys []string, +) { + foundMetricKeys := assertOtelMetrics(t, reader, expectedMetrics, false, false) + for _, key := range notExpectedMetricKeys { + assert.NotContains(t, foundMetricKeys, key) + } +} + func ExpectContainOtelMetricsKeys(t assert.TestingT, reader sdkmetric.Reader, expectedMetricsKeys []string) { expectedMetrics := make(map[string]any) for _, metricKey := range expectedMetricsKeys { @@ -62,7 +74,7 @@ func assertOtelMetrics( reader sdkmetric.Reader, expectedMetrics map[string]any, fullMatch, skipValAssert bool, -) { +) []string { var rm metricdata.ResourceMetrics assert.NoError(t, reader.Collect(context.Background(), &rm)) @@ -118,8 +130,9 @@ func assertOtelMetrics( expectedMetricsKeys = append(expectedMetricsKeys, k) } if fullMatch { - assert.ElementsMatch(t, expectedMetricsKeys, foundMetrics) + assert.ElementsMatch(t, foundMetrics, expectedMetricsKeys) } else { assert.Subset(t, foundMetrics, expectedMetricsKeys) } + return foundMetrics } diff --git a/x-pack/apm-server/sampling/processor.go b/x-pack/apm-server/sampling/processor.go index 3d50a9371fd..a66960b8f1a 100644 --- a/x-pack/apm-server/sampling/processor.go +++ b/x-pack/apm-server/sampling/processor.go @@ -8,18 +8,21 @@ import ( "context" "encoding/json" "os" + "runtime" "sync" "time" + "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "go.opentelemetry.io/otel/metric" "golang.org/x/sync/errgroup" "github.com/elastic/apm-data/model/modelpb" + "github.com/elastic/elastic-agent-libs/logp" + "github.com/elastic/apm-server/internal/logs" "github.com/elastic/apm-server/x-pack/apm-server/sampling/eventstorage" "github.com/elastic/apm-server/x-pack/apm-server/sampling/pubsub" - "github.com/elastic/elastic-agent-libs/logp" ) const ( @@ -41,6 +44,7 @@ type Processor struct { eventStore eventstorage.RW eventMetrics eventMetrics + shardLock *shardLock stopMu sync.Mutex stopping chan struct{} @@ -71,6 +75,7 @@ func NewProcessor(config Config, logger *logp.Logger) (*Processor, error) { rateLimitedLogger: logger.WithOptions(logs.WithRateLimit(loggerRateLimit)), groups: newTraceGroups(meter, config.Policies, config.MaxDynamicServices, config.IngestRateDecayFactor), eventStore: config.Storage, + shardLock: newShardLock(runtime.GOMAXPROCS(0)), stopping: make(chan struct{}), stopped: make(chan struct{}), } @@ -169,6 +174,9 @@ func (p *Processor) processTransaction(event *modelpb.APMEvent) (report, stored return true, false, nil } + p.shardLock.RLock(event.Trace.Id) + defer p.shardLock.RUnlock(event.Trace.Id) + traceSampled, err := p.eventStore.IsTraceSampled(event.Trace.Id) switch err { case nil: @@ -229,6 +237,9 @@ sampling policies without service name specified. } func (p *Processor) processSpan(event *modelpb.APMEvent) (report, stored bool, _ error) { + p.shardLock.RLock(event.Trace.Id) + defer p.shardLock.RUnlock(event.Trace.Id) + traceSampled, err := p.eventStore.IsTraceSampled(event.Trace.Id) if err != nil { if err == eventstorage.ErrNotFound { @@ -442,14 +453,18 @@ func (p *Processor) Run() error { } } + // We lock before WriteTraceSampled here to prevent race condition with IsTraceSampled from incoming events. + p.shardLock.Lock(traceID) if err := p.eventStore.WriteTraceSampled(traceID, true); err != nil { p.rateLimitedLogger.Warnf( "received error writing sampled trace: %s", err, ) } + p.shardLock.Unlock(traceID) events = events[:0] - if err := p.eventStore.ReadTraceEvents(traceID, &events); err != nil { + err = p.eventStore.ReadTraceEvents(traceID, &events) + if err != nil { p.rateLimitedLogger.Warnf( "received error reading trace events: %s", err, ) @@ -535,3 +550,37 @@ func sendTraceIDs(ctx context.Context, out chan<- string, traceIDs []string) err } return nil } + +type shardLock struct { + locks []sync.RWMutex +} + +func newShardLock(numShards int) *shardLock { + if numShards <= 0 { + panic("shardLock numShards must be greater than zero") + } + locks := make([]sync.RWMutex, numShards) + return &shardLock{locks: locks} +} + +func (s *shardLock) Lock(id string) { + s.getLock(id).Lock() +} + +func (s *shardLock) Unlock(id string) { + s.getLock(id).Unlock() +} + +func (s *shardLock) RLock(id string) { + s.getLock(id).RLock() +} + +func (s *shardLock) RUnlock(id string) { + s.getLock(id).RUnlock() +} + +func (s *shardLock) getLock(id string) *sync.RWMutex { + var h xxhash.Digest + _, _ = h.WriteString(id) + return &s.locks[h.Sum64()%uint64(len(s.locks))] +} diff --git a/x-pack/apm-server/sampling/processor_test.go b/x-pack/apm-server/sampling/processor_test.go index 318cdedfc8b..5676f09d412 100644 --- a/x-pack/apm-server/sampling/processor_test.go +++ b/x-pack/apm-server/sampling/processor_test.go @@ -10,6 +10,8 @@ import ( "math/rand" "os" "path/filepath" + "sync" + "sync/atomic" "testing" "time" @@ -20,14 +22,16 @@ import ( "github.com/stretchr/testify/require" sdkmetric "go.opentelemetry.io/otel/sdk/metric" "go.opentelemetry.io/otel/sdk/metric/metricdata" + "golang.org/x/sync/errgroup" "google.golang.org/protobuf/testing/protocmp" "github.com/elastic/apm-data/model/modelpb" + "github.com/elastic/elastic-agent-libs/logp/logptest" + "github.com/elastic/apm-server/internal/beater/monitoringtest" "github.com/elastic/apm-server/x-pack/apm-server/sampling" "github.com/elastic/apm-server/x-pack/apm-server/sampling/eventstorage" "github.com/elastic/apm-server/x-pack/apm-server/sampling/pubsub/pubsubtest" - "github.com/elastic/elastic-agent-libs/logp/logptest" ) func TestProcessUnsampled(t *testing.T) { @@ -829,6 +833,85 @@ func TestGracefulShutdown(t *testing.T) { assert.Equal(t, int(sampleRate*float64(totalTraces)), count) } +func TestPotentialRaceConditionConcurrent(t *testing.T) { + flushInterval := 1 * time.Second + tempdirConfig := newTempdirConfig(t) + tempdirConfig.Config.FlushInterval = flushInterval + tempdirConfig.Config.Policies = []sampling.Policy{ + {SampleRate: 1.0}, + } + + var reportedMu sync.Mutex + reported := map[string]struct{}{} + tempdirConfig.Config.BatchProcessor = modelpb.ProcessBatchFunc(func(ctx context.Context, batch *modelpb.Batch) error { + reportedMu.Lock() + defer reportedMu.Unlock() + for _, b := range batch.Clone() { + reported[b.Transaction.Id] = struct{}{} + } + return nil + }) + + processor, err := sampling.NewProcessor(tempdirConfig.Config, logptest.NewTestingLogger(t, "")) + require.NoError(t, err) + go processor.Run() + defer processor.Stop(context.Background()) + + var processed atomic.Int64 + var lateArrivals atomic.Int64 + eg, ctx := errgroup.WithContext(context.Background()) + for i := 0; i < 1000; i++ { + eg.Go(func() error { + first := true + index := i * 100000000 + + timer := time.NewTimer(flushInterval * 2) + defer timer.Stop() + + for { + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + default: + } + + batch := modelpb.Batch{{ + Trace: &modelpb.Trace{Id: "trace1"}, + Transaction: &modelpb.Transaction{ + Type: "type", + Id: fmt.Sprintf("transaction%08d", index), + Sampled: true, + }, + }} + + if first { + first = false + } else { + batch[0].ParentId = fmt.Sprintf("bar%08d", index) + } + + if err := processor.ProcessBatch(ctx, &batch); err != nil { + return err + } + index++ + processed.Add(1) + lateArrivals.Add(int64(len(batch))) + time.Sleep(time.Duration(rand.Intn(5)) * time.Millisecond) + } + }) + } + + require.NoError(t, eg.Wait()) + require.NoError(t, processor.Stop(context.Background())) + + reportedMu.Lock() + defer reportedMu.Unlock() + reportedPlusLateArrivals := int64(len(reported)) + lateArrivals.Load() + assert.Equal(t, processed.Load(), reportedPlusLateArrivals) +} + type testConfig struct { sampling.Config tempDir string