diff --git a/connectors/common/base.go b/connectors/common/base.go index 34ae579e..840f86b6 100644 --- a/connectors/common/base.go +++ b/connectors/common/base.go @@ -34,6 +34,7 @@ import ( "go.akshayshah.org/memhttp" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "golang.org/x/time/rate" ) const progressReportingIntervalSec = 10 @@ -51,6 +52,7 @@ type ConnectorSettings struct { TransformClient adiomv1connect.TransformServiceClient SourceDataType adiomv1.DataType DestinationDataType adiomv1.DataType + WriteRateLimit int } type maybeOptimizedConnectorService interface { @@ -81,6 +83,8 @@ type connector struct { progressTracker *ProgressTracker namespaceMappings map[string]string + + limiter *rate.Limiter } func isRetryable(err error) bool { @@ -944,6 +948,9 @@ func (c *connector) StartWriteFromChannel(flowId iface.FlowID, dataChannelID ifa if dataMsg.MutationType == iface.MutationType_Barrier { err := flowParallelWriter.ScheduleBarrier(dataMsg) if err != nil { + if errors.Is(err, WriterClosedErr) { + break + } slog.Error(fmt.Sprintf("Failed to schedule barrier message: %v", err)) } } else { @@ -951,6 +958,9 @@ func (c *connector) StartWriteFromChannel(flowId iface.FlowID, dataChannelID ifa writerProgress.dataMessages.Add(1) err := flowParallelWriter.ScheduleDataMessage(dataMsg) if err != nil { + if errors.Is(err, WriterClosedErr) { + break + } slog.Error(fmt.Sprintf("Failed to schedule data message: %v", err)) } } @@ -1048,12 +1058,17 @@ func NewConnector(desc string, impl adiomv1connect.ConnectorServiceClient, under if maybeOptimizedConnectorService == nil { underlying = impl } + var limiter *rate.Limiter + if settings.WriteRateLimit > 0 { + limiter = rate.NewLimiter(rate.Limit(settings.WriteRateLimit), settings.WriteRateLimit) + } return &connector{ desc: desc, impl: impl, maybeOptimizedImpl: maybeOptimizedConnectorService, settings: settings, namespaceMappings: map[string]string{}, + limiter: limiter, } } @@ -1064,6 +1079,11 @@ func (c *connector) ProcessDataMessages(dataMsgs []iface.DataMessage) error { switch dataMsg.MutationType { case iface.MutationType_InsertBatch: if len(msgs) > 0 { + if c.limiter != nil { + if err := c.limiter.WaitN(c.flowCtx, len(msgs)); err != nil { + return err + } + } ns := dataMsg.Loc _, err := c.maybeOptimizedImpl.WriteUpdates(c.flowCtx, connect.NewRequest(&adiomv1.WriteUpdatesRequest{ Namespace: ns, @@ -1076,6 +1096,11 @@ func (c *connector) ProcessDataMessages(dataMsgs []iface.DataMessage) error { c.progressTracker.UpdateWriteLSN(dataMsgs[i-1].SeqNum) msgs = nil } + if c.limiter != nil { + if err := c.limiter.WaitN(c.flowCtx, len(dataMsg.DataBatch)); err != nil { + return err + } + } _, err := c.maybeOptimizedImpl.WriteData(c.flowCtx, connect.NewRequest(&adiomv1.WriteDataRequest{ Namespace: dataMsg.Loc, Data: dataMsg.DataBatch, @@ -1107,6 +1132,11 @@ func (c *connector) ProcessDataMessages(dataMsgs []iface.DataMessage) error { } } if len(msgs) > 0 { + if c.limiter != nil { + if err := c.limiter.WaitN(c.flowCtx, len(msgs)); err != nil { + return err + } + } ns := dataMsgs[0].Loc _, err := c.impl.WriteUpdates(c.flowCtx, connect.NewRequest(&adiomv1.WriteUpdatesRequest{ Namespace: ns, diff --git a/connectors/common/parallel_writer.go b/connectors/common/parallel_writer.go index 2a9522ca..1c58a0c7 100644 --- a/connectors/common/parallel_writer.go +++ b/connectors/common/parallel_writer.go @@ -8,6 +8,7 @@ package common import ( "context" + "errors" "fmt" "hash/fnv" "log/slog" @@ -19,6 +20,8 @@ import ( "golang.org/x/sync/errgroup" ) +var WriterClosedErr = errors.New("writer closed") + type ParallelWriterConnector interface { HandleBarrierMessage(iface.DataMessage) error ProcessDataMessages([]iface.DataMessage) error @@ -161,7 +164,7 @@ func (bwa *ParallelWriter) ScheduleBarrier(barrierMsg iface.DataMessage) error { select { case <-bwa.blockBarrier: case <-bwa.ctx.Done(): - return fmt.Errorf("writer closed") + return WriterClosedErr } } slog.Debug("Blocking barrier unblocked.") @@ -244,6 +247,9 @@ func (ww *writerWorker) run() error { if len(batch) > 0 { err := ww.parallelWriter.connector.ProcessDataMessages(batch) if err != nil { + if errors.Is(err, context.Canceled) { + return err + } slog.Error(fmt.Sprintf("Worker %v failed to process data messages: %v", ww.id, err)) if err2 := ww.parallelWriter.connector.HandlerError(err); err2 != nil { return err2 @@ -253,6 +259,9 @@ func (ww *writerWorker) run() error { } if len(multiBatch) > 0 { if err := ww.sendMultiBatch(multiBatch); err != nil { + if errors.Is(err, context.Canceled) { + return err + } slog.Error(fmt.Sprintf("Worker %v failed to process data messages: %v", ww.id, err)) if err2 := ww.parallelWriter.connector.HandlerError(err); err2 != nil { return err2 @@ -288,6 +297,9 @@ func (ww *writerWorker) run() error { if isLastWorker { err := ww.parallelWriter.connector.HandleBarrierMessage(msg) if err != nil { + if errors.Is(err, context.Canceled) { + return err + } slog.Error(fmt.Sprintf("Worker %v failed to handle barrier message: %v", ww.id, err)) if err2 := ww.parallelWriter.connector.HandlerError(err); err2 != nil { return err2 @@ -303,6 +315,9 @@ func (ww *writerWorker) run() error { multiBatchCount += 1 if (ww.parallelWriter.maxBatchSize > 0 && multiBatchCount >= ww.parallelWriter.maxBatchSize) || msg.MutationType == iface.MutationType_InsertBatch || len(ww.queue) == 0 { if err := ww.sendMultiBatch(multiBatch); err != nil { + if errors.Is(err, context.Canceled) { + return err + } slog.Error(fmt.Sprintf("Worker %v failed to process data messages: %v", ww.id, err)) if err2 := ww.parallelWriter.connector.HandlerError(err); err2 != nil { return err2 @@ -318,6 +333,9 @@ func (ww *writerWorker) run() error { if len(batch) > 0 && msg.Loc != batch[0].Loc { err := ww.parallelWriter.connector.ProcessDataMessages(batch) if err != nil { + if errors.Is(err, context.Canceled) { + return err + } slog.Error(fmt.Sprintf("Worker %v failed to process data messages: %v", ww.id, err)) if err2 := ww.parallelWriter.connector.HandlerError(err); err2 != nil { return err2 @@ -331,6 +349,9 @@ func (ww *writerWorker) run() error { if (ww.parallelWriter.maxBatchSize > 0 && len(batch) >= ww.parallelWriter.maxBatchSize) || msg.MutationType == iface.MutationType_InsertBatch || len(ww.queue) == 0 { err := ww.parallelWriter.connector.ProcessDataMessages(batch) if err != nil { + if errors.Is(err, context.Canceled) { + return err + } if msg.MutationType == iface.MutationType_InsertBatch { d := 0 if batch[0].Data != nil { @@ -355,14 +376,14 @@ func (ww *writerWorker) addMessage(msg iface.DataMessage) error { select { case ww.queue <- msg: case <-ww.parallelWriter.ctx.Done(): - return fmt.Errorf("writer closed") + return WriterClosedErr } if ww.clogSize > 0 && msg.MutationType == iface.MutationType_InsertBatch { for range ww.clogSize { select { case ww.queue <- iface.DataMessage{MutationType: iface.MutationType_Ignore}: case <-ww.parallelWriter.ctx.Done(): - return fmt.Errorf("writer closed") + return WriterClosedErr } } } diff --git a/internal/app/app.go b/internal/app/app.go index 8cb174a5..fb2c7a0c 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -243,6 +243,7 @@ func runDsync(c *cli.Context) error { MultinamespaceBatcher: o.MultinamespaceBatcher, SyncMode: o.Mode, ReverseRequestedFlag: o.Reverse, + WriteRateLimit: o.WriteRateLimit, }) var wg sync.WaitGroup diff --git a/internal/app/options/flags.go b/internal/app/options/flags.go index 713eea54..974e86d8 100644 --- a/internal/app/options/flags.go +++ b/internal/app/options/flags.go @@ -184,6 +184,10 @@ func GetFlagsAndBeforeFunc() ([]cli.Flag, cli.BeforeFunc) { Required: false, Hidden: true, }), + altsrc.NewIntFlag(&cli.IntFlag{ + Name: "write-rate-limit", + Usage: "If set, overall per second rate limit on the destination. If set, must be larger than any batch sizes.", + }), } before := func(c *cli.Context) error { diff --git a/internal/app/options/options.go b/internal/app/options/options.go index c6520f28..80142471 100644 --- a/internal/app/options/options.go +++ b/internal/app/options/options.go @@ -42,6 +42,8 @@ type Options struct { WriterMaxBatchSize int MultinamespaceBatcher bool Mode string + + WriteRateLimit int } // works with a copy of the struct to avoid modifying the original @@ -78,5 +80,7 @@ func NewFromCLIContext(c *cli.Context) (Options, error) { o.Mode = c.String("mode") o.Reverse = c.Bool("reverse") + o.WriteRateLimit = c.Int("write-rate-limit") + return o, nil } diff --git a/runners/local/runner.go b/runners/local/runner.go index 99b9105b..2de00b0f 100644 --- a/runners/local/runner.go +++ b/runners/local/runner.go @@ -79,6 +79,8 @@ type RunnerLocalSettings struct { WriterMaxBatchSize int SyncMode string MultinamespaceBatcher bool + + WriteRateLimit int } const ( @@ -109,6 +111,7 @@ func NewRunnerLocal(settings RunnerLocalSettings) *RunnerLocal { TransformClient: settings.TransformClient, SourceDataType: settings.SrcDataType, DestinationDataType: settings.DstDataType, + WriteRateLimit: settings.WriteRateLimit, } if settings.LoadLevel != "" { btc := GetBaseThreadCount(settings.LoadLevel) @@ -307,6 +310,7 @@ func (r *RunnerLocal) GracefulShutdown() { if r.cancelIntegrityCtx != nil { r.cancelIntegrityCtx() } + _ = r.dst.Interrupt(r.activeFlowID) } func (r *RunnerLocal) Teardown() {