Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions internal/beater/monitoringtest/opentelemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ func ExpectContainOtelMetrics(
assertOtelMetrics(t, reader, expectedMetrics, false, false)
}

// ExpectContainAndNotContainOtelMetrics runs the non-exact ("contains")
// metric assertion for expectedMetrics against reader, and then additionally
// asserts that none of notExpectedMetricKeys is among the metric keys that
// assertion reported as found.
func ExpectContainAndNotContainOtelMetrics(
	t *testing.T,
	reader sdkmetric.Reader,
	expectedMetrics map[string]any,
	notExpectedMetricKeys []string,
) {
	// fullMatch=false, skipValAssert=false: expected metrics must be a subset
	// of what was collected.
	found := assertOtelMetrics(t, reader, expectedMetrics, false, false)
	for idx := range notExpectedMetricKeys {
		assert.NotContains(t, found, notExpectedMetricKeys[idx])
	}
}

func ExpectContainOtelMetricsKeys(t assert.TestingT, reader sdkmetric.Reader, expectedMetricsKeys []string) {
expectedMetrics := make(map[string]any)
for _, metricKey := range expectedMetricsKeys {
Expand All @@ -62,7 +74,7 @@ func assertOtelMetrics(
reader sdkmetric.Reader,
expectedMetrics map[string]any,
fullMatch, skipValAssert bool,
) {
) []string {
var rm metricdata.ResourceMetrics
assert.NoError(t, reader.Collect(context.Background(), &rm))

Expand Down Expand Up @@ -118,8 +130,9 @@ func assertOtelMetrics(
expectedMetricsKeys = append(expectedMetricsKeys, k)
}
if fullMatch {
assert.ElementsMatch(t, expectedMetricsKeys, foundMetrics)
assert.ElementsMatch(t, foundMetrics, expectedMetricsKeys)
} else {
assert.Subset(t, foundMetrics, expectedMetricsKeys)
}
return foundMetrics
}
53 changes: 51 additions & 2 deletions x-pack/apm-server/sampling/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,21 @@ import (
"context"
"encoding/json"
"os"
"runtime"
"sync"
"time"

"github.com/cespare/xxhash/v2"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/metric"
"golang.org/x/sync/errgroup"

"github.com/elastic/apm-data/model/modelpb"
"github.com/elastic/elastic-agent-libs/logp"

"github.com/elastic/apm-server/internal/logs"
"github.com/elastic/apm-server/x-pack/apm-server/sampling/eventstorage"
"github.com/elastic/apm-server/x-pack/apm-server/sampling/pubsub"
"github.com/elastic/elastic-agent-libs/logp"
)

const (
Expand All @@ -41,6 +44,7 @@ type Processor struct {

eventStore eventstorage.RW
eventMetrics eventMetrics
shardLock *shardLock

stopMu sync.Mutex
stopping chan struct{}
Expand Down Expand Up @@ -71,6 +75,7 @@ func NewProcessor(config Config, logger *logp.Logger) (*Processor, error) {
rateLimitedLogger: logger.WithOptions(logs.WithRateLimit(loggerRateLimit)),
groups: newTraceGroups(meter, config.Policies, config.MaxDynamicServices, config.IngestRateDecayFactor),
eventStore: config.Storage,
shardLock: newShardLock(runtime.GOMAXPROCS(0)),
stopping: make(chan struct{}),
stopped: make(chan struct{}),
}
Expand Down Expand Up @@ -169,6 +174,9 @@ func (p *Processor) processTransaction(event *modelpb.APMEvent) (report, stored
return true, false, nil
}

p.shardLock.RLock(event.Trace.Id)
defer p.shardLock.RUnlock(event.Trace.Id)

traceSampled, err := p.eventStore.IsTraceSampled(event.Trace.Id)
switch err {
case nil:
Expand Down Expand Up @@ -229,6 +237,9 @@ sampling policies without service name specified.
}

func (p *Processor) processSpan(event *modelpb.APMEvent) (report, stored bool, _ error) {
p.shardLock.RLock(event.Trace.Id)
defer p.shardLock.RUnlock(event.Trace.Id)

traceSampled, err := p.eventStore.IsTraceSampled(event.Trace.Id)
if err != nil {
if err == eventstorage.ErrNotFound {
Expand Down Expand Up @@ -442,14 +453,18 @@ func (p *Processor) Run() error {
}
}

// We lock before WriteTraceSampled here to prevent race condition with IsTraceSampled from incoming events.
p.shardLock.Lock(traceID)
if err := p.eventStore.WriteTraceSampled(traceID, true); err != nil {
p.rateLimitedLogger.Warnf(
"received error writing sampled trace: %s", err,
)
}
p.shardLock.Unlock(traceID)

events = events[:0]
if err := p.eventStore.ReadTraceEvents(traceID, &events); err != nil {
err = p.eventStore.ReadTraceEvents(traceID, &events)
if err != nil {
p.rateLimitedLogger.Warnf(
"received error reading trace events: %s", err,
)
Expand Down Expand Up @@ -535,3 +550,37 @@ func sendTraceIDs(ctx context.Context, out chan<- string, traceIDs []string) err
}
return nil
}

// shardLock is a fixed-size collection of read/write mutexes. Callers lock
// by string id; each id is mapped onto one of the mutexes, so distinct ids
// may end up sharing the same lock.
type shardLock struct {
	// locks holds one mutex per shard; its length is fixed at construction.
	locks []sync.RWMutex
}

// newShardLock returns a shardLock with numShards independent mutexes.
// It panics if numShards is not strictly positive.
func newShardLock(numShards int) *shardLock {
	if numShards < 1 {
		panic("shardLock numShards must be greater than zero")
	}
	return &shardLock{locks: make([]sync.RWMutex, numShards)}
}

// Lock acquires the exclusive (write) lock of the shard that id maps to.
// Release it by calling Unlock with the same id.
func (s *shardLock) Lock(id string) {
	mu := s.getLock(id)
	mu.Lock()
}

// Unlock releases the exclusive (write) lock of the shard that id maps to.
// It must be paired with a preceding Lock for the same id.
func (s *shardLock) Unlock(id string) {
	mu := s.getLock(id)
	mu.Unlock()
}

// RLock acquires the shared (read) lock of the shard that id maps to.
// Release it by calling RUnlock with the same id.
func (s *shardLock) RLock(id string) {
	mu := s.getLock(id)
	mu.RLock()
}

// RUnlock releases the shared (read) lock of the shard that id maps to.
// It must be paired with a preceding RLock for the same id.
func (s *shardLock) RUnlock(id string) {
	mu := s.getLock(id)
	mu.RUnlock()
}

// getLock returns the mutex of the shard that id is assigned to. The shard
// is chosen by hashing id and reducing the hash modulo the number of locks,
// so the same id always yields the same mutex, while different ids may
// collide on one.
func (s *shardLock) getLock(id string) *sync.RWMutex {
	// Use the one-shot helper rather than writing into a zero-valued
	// xxhash.Digest: the package documents that a zero-valued Digest is not
	// ready to receive writes until Reset (or New) has been called.
	shard := xxhash.Sum64String(id) % uint64(len(s.locks))
	return &s.locks[shard]
}
85 changes: 84 additions & 1 deletion x-pack/apm-server/sampling/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"math/rand"
"os"
"path/filepath"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -20,14 +22,16 @@ import (
"github.com/stretchr/testify/require"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/testing/protocmp"

"github.com/elastic/apm-data/model/modelpb"
"github.com/elastic/elastic-agent-libs/logp/logptest"

"github.com/elastic/apm-server/internal/beater/monitoringtest"
"github.com/elastic/apm-server/x-pack/apm-server/sampling"
"github.com/elastic/apm-server/x-pack/apm-server/sampling/eventstorage"
"github.com/elastic/apm-server/x-pack/apm-server/sampling/pubsub/pubsubtest"
"github.com/elastic/elastic-agent-libs/logp/logptest"
)

func TestProcessUnsampled(t *testing.T) {
Expand Down Expand Up @@ -829,6 +833,85 @@ func TestGracefulShutdown(t *testing.T) {
assert.Equal(t, int(sampleRate*float64(totalTraces)), count)
}

// TestPotentialRaceConditionConcurrent sends transactions for a single trace
// ID from many goroutines while the processor is running, for twice the
// flush interval. Every successfully processed event must be accounted for
// exactly once: either it was left in the batch by ProcessBatch (counted as
// a late arrival), or its transaction ID was later seen by the configured
// BatchProcessor.
func TestPotentialRaceConditionConcurrent(t *testing.T) {
	flushInterval := 1 * time.Second
	cfg := newTempdirConfig(t)
	cfg.Config.FlushInterval = flushInterval
	// Sample everything.
	cfg.Config.Policies = []sampling.Policy{
		{SampleRate: 1.0},
	}

	// reported is the set of transaction IDs handed to the BatchProcessor.
	var reportedMu sync.Mutex
	reported := map[string]struct{}{}
	recordReported := func(ctx context.Context, batch *modelpb.Batch) error {
		reportedMu.Lock()
		defer reportedMu.Unlock()
		for _, event := range batch.Clone() {
			reported[event.Transaction.Id] = struct{}{}
		}
		return nil
	}
	cfg.Config.BatchProcessor = modelpb.ProcessBatchFunc(recordReported)

	processor, err := sampling.NewProcessor(cfg.Config, logptest.NewTestingLogger(t, ""))
	require.NoError(t, err)
	go processor.Run()
	defer processor.Stop(context.Background())

	var (
		processed    atomic.Int64 // events successfully passed to ProcessBatch
		lateArrivals atomic.Int64 // events still in the batch after ProcessBatch
	)
	eg, ctx := errgroup.WithContext(context.Background())
	for i := 0; i < 1000; i++ {
		eg.Go(func() error {
			// Give each goroutine its own range of transaction IDs.
			index := i * 100000000
			first := true

			timer := time.NewTimer(flushInterval * 2)
			defer timer.Stop()

			for {
				// Non-blocking check for the deadline or a failure in
				// another goroutine.
				select {
				case <-timer.C:
					return nil
				case <-ctx.Done():
					return ctx.Err()
				default:
				}

				batch := modelpb.Batch{{
					Trace: &modelpb.Trace{Id: "trace1"},
					Transaction: &modelpb.Transaction{
						Type:    "type",
						Id:      fmt.Sprintf("transaction%08d", index),
						Sampled: true,
					},
				}}

				// Only the first event of each goroutine has no parent ID.
				if !first {
					batch[0].ParentId = fmt.Sprintf("bar%08d", index)
				}
				first = false

				if err := processor.ProcessBatch(ctx, &batch); err != nil {
					return err
				}
				index++
				processed.Add(1)
				lateArrivals.Add(int64(len(batch)))
				time.Sleep(time.Duration(rand.Intn(5)) * time.Millisecond)
			}
		})
	}

	require.NoError(t, eg.Wait())
	require.NoError(t, processor.Stop(context.Background()))

	reportedMu.Lock()
	defer reportedMu.Unlock()
	accountedFor := int64(len(reported)) + lateArrivals.Load()
	assert.Equal(t, processed.Load(), accountedFor)
}

type testConfig struct {
sampling.Config
tempDir string
Expand Down