diff --git a/pkg/chipingress/batch/client.go b/pkg/chipingress/batch/client.go
index d0aa6f26d..e02704b26 100644
--- a/pkg/chipingress/batch/client.go
+++ b/pkg/chipingress/batch/client.go
@@ -3,10 +3,14 @@ package batch
 import (
 	"context"
 	"errors"
+	"strconv"
 	"sync"
+	"sync/atomic"
 	"time"
 
+	cepb "github.com/cloudevents/sdk-go/binding/format/protobuf/v2/pb"
 	"go.uber.org/zap"
+	"google.golang.org/protobuf/proto"
 
 	"github.com/smartcontractkit/chainlink-common/pkg/chipingress"
 )
@@ -16,10 +20,17 @@ type messageWithCallback struct {
 	callback func(error)
 }
 
+// seqnumKey identifies one sequence-number stream: a (source, type) pair.
+type seqnumKey struct {
+	source    string
+	eventType string
+}
+
 // Client is a batching client that accumulates messages and sends them in batches.
 type Client struct {
 	client             chipingress.Client
 	batchSize          int
+	cloneEvent         bool
 	maxConcurrentSends chan struct{}
 	batchInterval      time.Duration
 	maxPublishTimeout  time.Duration
@@ -31,6 +42,7 @@ type Client struct {
 	shutdownOnce       sync.Once
 	batcherDone        chan struct{}
 	cancelBatcher      context.CancelFunc
+	counters           sync.Map // map[seqnumKey]*atomic.Uint64 for per-(source,type) seqnum, cleared on Stop()
 }
 
 // Opt is a functional option for configuring the batch Client.
@@ -42,6 +54,7 @@ func NewBatchClient(client chipingress.Client, opts ...Opt) (*Client, error) {
 		client:             client,
 		log:                zap.NewNop().Sugar(),
 		batchSize:          10,
+		cloneEvent:         true,
 		maxConcurrentSends: make(chan struct{}, 1),
 		messageBuffer:      make(chan *messageWithCallback, 200),
 		batchInterval:      100 * time.Millisecond,
@@ -123,15 +136,40 @@ func (b *Client) Stop() {
 		case <-ctx.Done(): // timeout or context cancelled
 			b.log.Warnw("timed out waiting for shutdown to finish, force closing", "timeout", b.shutdownTimeout)
 		}
+
+		// Release per-stream seqnum state to avoid unbounded growth from high-cardinality source/type values.
+		b.clearCounters()
 	})
 }
 
+// clearCounters drops every per-(source, type) sequence counter.
+func (b *Client) clearCounters() {
+	b.counters.Clear()
+}
+
+// seqnumFor returns the next sequence number for the given source+type pair.
+// Each unique (source, type) pair has its own independent counter starting at 1.
+func (b *Client) seqnumFor(source, typ string) uint64 {
+	key := seqnumKey{source: source, eventType: typ}
+	// Fast path: Load first so an established stream does not allocate a
+	// throwaway counter on every call just to pass it to LoadOrStore.
+	v, ok := b.counters.Load(key)
+	if !ok {
+		v, _ = b.counters.LoadOrStore(key, new(atomic.Uint64))
+	}
+	return v.(*atomic.Uint64).Add(1)
+}
+
 // QueueMessage queues a single message to the batch client with an optional callback.
 // The callback will be invoked after the batch containing this message is sent.
 // The callback receives an error parameter (nil on success).
 // Callbacks are invoked from goroutines
 // Returns immediately with no blocking - drops message if channel is full.
 // Returns an error if the message was dropped.
+// QueueMessage stamps/overwrites the "seqnum" extension on the event it buffers.
+// By default, it clones the input event first (WithEventClone(true)) so caller-owned
+// objects are not mutated and queued snapshots remain immutable under pointer reuse.
+// If cloning is disabled via WithEventClone(false), the caller event is mutated in place.
 func (b *Client) QueueMessage(event *chipingress.CloudEventPb, callback func(error)) error {
 	if event == nil {
 		return nil
@@ -144,8 +182,29 @@ func (b *Client) QueueMessage(event *chipingress.CloudEventPb, callback func(err
 	default:
 	}
 
+	eventToQueue := event
+	if b.cloneEvent {
+		// Clone the caller-owned event so queued messages keep an immutable seqnum snapshot.
+		eventCopy, ok := proto.Clone(event).(*chipingress.CloudEventPb)
+		if !ok {
+			return errors.New("failed to clone event")
+		}
+		eventToQueue = eventCopy
+	}
+
+	// Stamp seqnum extension attribute using the event snapshot being queued.
+	seq := b.seqnumFor(eventToQueue.Source, eventToQueue.Type)
+	if eventToQueue.Attributes == nil {
+		eventToQueue.Attributes = make(map[string]*cepb.CloudEventAttributeValue)
+	}
+	eventToQueue.Attributes["seqnum"] = &cepb.CloudEventAttributeValue{
+		Attr: &cepb.CloudEventAttributeValue_CeString{
+			CeString: strconv.FormatUint(seq, 10),
+		},
+	}
+
 	msg := &messageWithCallback{
-		event:    event,
+		event:    eventToQueue,
 		callback: callback,
 	}
 
@@ -198,6 +257,14 @@ func WithBatchSize(batchSize int) Opt {
 	}
 }
 
+// WithEventClone controls whether QueueMessage clones events before stamping seqnum and buffering.
+// Defaults to true for safety when caller reuses event pointers.
+func WithEventClone(clone bool) Opt {
+	return func(c *Client) {
+		c.cloneEvent = clone
+	}
+}
+
 // WithMaxConcurrentSends sets the maximum number of concurrent batch send operations
 func WithMaxConcurrentSends(maxConcurrentSends int) Opt {
 	return func(c *Client) {
diff --git a/pkg/chipingress/batch/client_test.go b/pkg/chipingress/batch/client_test.go
index 1f7567751..7f8c356fb 100644
--- a/pkg/chipingress/batch/client_test.go
+++ b/pkg/chipingress/batch/client_test.go
@@ -2,7 +2,9 @@ package batch
 
 import (
 	"context"
+	"slices"
 	"strconv"
+	"sync"
 	"testing"
 	"time"
 
@@ -27,6 +29,16 @@ func TestNewBatchClient(t *testing.T) {
 		assert.Equal(t, 100, client.batchSize)
 	})
 
+	t.Run("WithEventClone", func(t *testing.T) {
+		client, err := NewBatchClient(nil)
+		require.NoError(t, err)
+		assert.True(t, client.cloneEvent)
+
+		client, err = NewBatchClient(nil, WithEventClone(false))
+		require.NoError(t, err)
+		assert.False(t, client.cloneEvent)
+	})
+
 	t.Run("WithMaxConcurrentSends", func(t *testing.T) {
 		client, err := NewBatchClient(nil, WithMaxConcurrentSends(10))
 		require.NoError(t, err)
@@ -853,4 +865,249 @@ func TestStop(t *testing.T) {
 		require.Error(t, err)
 		assert.Contains(t, err.Error(), "shutdown")
 	})
+
+	t.Run("clears seqnum counters on Stop", func(t *testing.T) {
+		mockClient := mocks.NewClient(t)
+		client, err := NewBatchClient(mockClient, WithBatchSize(10))
+		require.NoError(t, err)
+
+		_ = client.seqnumFor("domain-a", "entity-x")
+		_ = client.seqnumFor("domain-b", "entity-y")
+		assert.Equal(t, 2, countCounters(&client.counters))
+
+		ctx, cancel := context.WithCancel(t.Context())
+		defer cancel()
+
+		client.Start(ctx)
+		client.Stop()
+
+		assert.Equal(t, 0, countCounters(&client.counters))
+	})
+}
+
+// countCounters returns the number of entries currently stored in counters.
+func countCounters(counters *sync.Map) int {
+	n := 0
+	counters.Range(func(_, _ any) bool {
+		n++
+		return true
+	})
+	return n
+}
+
+// TestSeqnum covers the "seqnum" extension that QueueMessage stamps on queued events.
+func TestSeqnum(t *testing.T) {
+	t.Run("dropped messages consume seqnum and create detectable gaps", func(t *testing.T) {
+		client, err := NewBatchClient(nil, WithMessageBuffer(1))
+		require.NoError(t, err)
+
+		first := &chipingress.CloudEventPb{Id: "id-1", Source: "domain-a", Type: "entity-x"}
+		second := &chipingress.CloudEventPb{Id: "id-2", Source: "domain-a", Type: "entity-x"}
+		third := &chipingress.CloudEventPb{Id: "id-3", Source: "domain-a", Type: "entity-x"}
+
+		err = client.QueueMessage(first, nil)
+		require.NoError(t, err)
+
+		err = client.QueueMessage(second, nil)
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "message buffer is full")
+
+		msg := <-client.messageBuffer
+		require.NotNil(t, msg.event.Attributes["seqnum"])
+		assert.Equal(t, "1", msg.event.Attributes["seqnum"].GetCeString())
+
+		err = client.QueueMessage(third, nil)
+		require.NoError(t, err)
+
+		msg = <-client.messageBuffer
+		require.NotNil(t, msg.event.Attributes["seqnum"])
+		assert.Equal(t, "3", msg.event.Attributes["seqnum"].GetCeString())
+	})
+
+	t.Run("reusing event pointer preserves queued seqnum snapshots", func(t *testing.T) {
+		client, err := NewBatchClient(nil, WithMessageBuffer(2))
+		require.NoError(t, err)
+
+		event := &chipingress.CloudEventPb{Id: "id-1", Source: "domain-a", Type: "entity-x"}
+
+		err = client.QueueMessage(event, nil)
+		require.NoError(t, err)
+		err = client.QueueMessage(event, nil)
+		require.NoError(t, err)
+
+		first := <-client.messageBuffer
+		second := <-client.messageBuffer
+
+		require.NotNil(t, first.event.Attributes["seqnum"])
+		require.NotNil(t, second.event.Attributes["seqnum"])
+		assert.Equal(t, "1", first.event.Attributes["seqnum"].GetCeString())
+		assert.Equal(t, "2", second.event.Attributes["seqnum"].GetCeString())
+	})
+
+	t.Run("reusing event pointer can overwrite queued seqnum when clone disabled", func(t *testing.T) {
+		client, err := NewBatchClient(nil, WithMessageBuffer(2), WithEventClone(false))
+		require.NoError(t, err)
+
+		event := &chipingress.CloudEventPb{Id: "id-1", Source: "domain-a", Type: "entity-x"}
+
+		err = client.QueueMessage(event, nil)
+		require.NoError(t, err)
+		err = client.QueueMessage(event, nil)
+		require.NoError(t, err)
+
+		first := <-client.messageBuffer
+		second := <-client.messageBuffer
+
+		require.NotNil(t, first.event.Attributes["seqnum"])
+		require.NotNil(t, second.event.Attributes["seqnum"])
+		assert.Equal(t, "2", first.event.Attributes["seqnum"].GetCeString())
+		assert.Equal(t, "2", second.event.Attributes["seqnum"].GetCeString())
+	})
+
+	t.Run("stamps sequential seqnum for same source+type", func(t *testing.T) {
+		client, err := NewBatchClient(nil, WithMessageBuffer(10))
+		require.NoError(t, err)
+
+		events := []*chipingress.CloudEventPb{
+			{Id: "id-1", Source: "domain-a", Type: "entity-x"},
+			{Id: "id-2", Source: "domain-a", Type: "entity-x"},
+			{Id: "id-3", Source: "domain-a", Type: "entity-x"},
+		}
+
+		for _, e := range events {
+			err := client.QueueMessage(e, nil)
+			require.NoError(t, err)
+		}
+
+		// Drain buffer and verify seqnums
+		for i, expected := range []string{"1", "2", "3"} {
+			msg := <-client.messageBuffer
+			require.NotNil(t, msg.event.Attributes, "event %d should have attributes", i)
+			seqAttr, ok := msg.event.Attributes["seqnum"]
+			require.True(t, ok, "event %d should have seqnum attribute", i)
+			assert.Equal(t, expected, seqAttr.GetCeString(), "event %d seqnum mismatch", i)
+		}
+	})
+
+	t.Run("independent counters per source+type pair", func(t *testing.T) {
+		client, err := NewBatchClient(nil, WithMessageBuffer(10))
+		require.NoError(t, err)
+
+		// Queue events with different source+type combinations
+		events := []*chipingress.CloudEventPb{
+			{Id: "a1", Source: "domain-a", Type: "entity-x"},
+			{Id: "b1", Source: "domain-b", Type: "entity-y"},
+			{Id: "a2", Source: "domain-a", Type: "entity-x"},
+			{Id: "b2", Source: "domain-b", Type: "entity-y"},
+			{Id: "c1", Source: "domain-a", Type: "entity-z"}, // same domain, different type
+		}
+
+		for _, e := range events {
+			err := client.QueueMessage(e, nil)
+			require.NoError(t, err)
+		}
+
+		// Expected seqnums by event ID
+		expected := map[string]string{
+			"a1": "1", // first for domain-a/entity-x
+			"b1": "1", // first for domain-b/entity-y
+			"a2": "2", // second for domain-a/entity-x
+			"b2": "2", // second for domain-b/entity-y
+			"c1": "1", // first for domain-a/entity-z (new type)
+		}
+
+		for range events {
+			msg := <-client.messageBuffer
+			seqAttr := msg.event.Attributes["seqnum"]
+			require.NotNil(t, seqAttr)
+			assert.Equal(t, expected[msg.event.Id], seqAttr.GetCeString(),
+				"seqnum mismatch for event %s", msg.event.Id)
+		}
+	})
+
+	t.Run("source and type values containing separator do not collide", func(t *testing.T) {
+		client, err := NewBatchClient(nil, WithMessageBuffer(10))
+		require.NoError(t, err)
+
+		events := []*chipingress.CloudEventPb{
+			{Id: "sep-1", Source: "a\x00b", Type: "c"},
+			{Id: "sep-2", Source: "a", Type: "b\x00c"},
+		}
+
+		for _, e := range events {
+			err := client.QueueMessage(e, nil)
+			require.NoError(t, err)
+		}
+
+		expected := map[string]string{
+			"sep-1": "1",
+			"sep-2": "1",
+		}
+
+		for range events {
+			msg := <-client.messageBuffer
+			seqAttr := msg.event.Attributes["seqnum"]
+			require.NotNil(t, seqAttr)
+			assert.Equal(t, expected[msg.event.Id], seqAttr.GetCeString(),
+				"seqnum mismatch for event %s", msg.event.Id)
+		}
+	})
+
+	t.Run("concurrent access produces unique seqnums", func(t *testing.T) {
+		client, err := NewBatchClient(nil, WithMessageBuffer(1000))
+		require.NoError(t, err)
+
+		const numGoroutines = 50
+		const eventsPerGoroutine = 10
+		totalEvents := numGoroutines * eventsPerGoroutine
+
+		queueErrors := make(chan error, totalEvents)
+		var wg sync.WaitGroup
+		wg.Add(numGoroutines)
+
+		for g := 0; g < numGoroutines; g++ {
+			go func(goroutineID int) {
+				defer wg.Done()
+				for i := 0; i < eventsPerGoroutine; i++ {
+					event := &chipingress.CloudEventPb{
+						Id:     strconv.Itoa(goroutineID*eventsPerGoroutine + i),
+						Source: "concurrent-domain",
+						Type:   "concurrent-type",
+					}
+					if err := client.QueueMessage(event, nil); err != nil {
+						queueErrors <- err
+					}
+				}
+			}(g)
+		}
+
+		wg.Wait()
+		close(queueErrors)
+
+		var collectedQueueErrors []error
+		for err := range queueErrors {
+			collectedQueueErrors = append(collectedQueueErrors, err)
+		}
+		require.Empty(t, collectedQueueErrors, "expected all concurrent QueueMessage calls to succeed")
+
+		// Collect all seqnums
+		seqnums := make([]uint64, 0, totalEvents)
+		for i := 0; i < totalEvents; i++ {
+			msg := <-client.messageBuffer
+			seqAttr := msg.event.Attributes["seqnum"]
+			require.NotNil(t, seqAttr)
+			seq, err := strconv.ParseUint(seqAttr.GetCeString(), 10, 64)
+			require.NoError(t, err)
+			seqnums = append(seqnums, seq)
+		}
+
+		// Sort and verify all unique and in range [1, totalEvents]
+		slices.Sort(seqnums)
+
+		expectedSeq := uint64(1)
+		for i, seq := range seqnums {
+			assert.Equal(t, expectedSeq, seq, "seqnum at index %d should be %d", i, expectedSeq)
+			expectedSeq++
+		}
+	})
 }