diff --git a/connectors/s3vector/conn.go b/connectors/s3vector/conn.go index 89e50516..ac411bc1 100644 --- a/connectors/s3vector/conn.go +++ b/connectors/s3vector/conn.go @@ -19,7 +19,6 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsontype" "golang.org/x/sync/errgroup" - "golang.org/x/time/rate" ) type conn struct { @@ -28,7 +27,7 @@ type conn struct { vectorKey string maxParallelism int batchSize int - limiter *rate.Limiter + limiter util.NamespaceLimiter } // GeneratePlan implements adiomv1connect.ConnectorServiceHandler. @@ -87,9 +86,10 @@ func (c *conn) WriteData(ctx context.Context, r *connect.Request[adiomv1.WriteDa eg, ctx := errgroup.WithContext(ctx) eg.SetLimit(c.maxParallelism) + limiter := c.limiter.Get(r.Msg.GetNamespace()) for batch := range slices.Chunk(vectors, c.batchSize) { eg.Go(func() error { - if err := c.limiter.WaitN(ctx, len(batch)); err != nil { + if err := limiter.WaitN(ctx, len(batch)); err != nil { return fmt.Errorf("err in limiter: %w", err) } if _, err := c.client.PutVectors(ctx, &s3vectors.PutVectorsInput{ @@ -220,11 +220,12 @@ func (c *conn) WriteUpdates(ctx context.Context, r *connect.Request[adiomv1.Writ } } if len(updates) > 0 { + limiter := c.limiter.Get(r.Msg.GetNamespace()) eg, ctx := errgroup.WithContext(ctx) eg.SetLimit(c.maxParallelism) for batch := range slices.Chunk(toDelete, c.batchSize) { eg.Go(func() error { - if err := c.limiter.WaitN(ctx, len(batch)); err != nil { + if err := limiter.WaitN(ctx, len(batch)); err != nil { return fmt.Errorf("err in limiter: %w", err) } if _, err := c.client.DeleteVectors(ctx, &s3vectors.DeleteVectorsInput{ @@ -239,7 +240,7 @@ func (c *conn) WriteUpdates(ctx context.Context, r *connect.Request[adiomv1.Writ } for batch := range slices.Chunk(vectors, c.batchSize) { eg.Go(func() error { - if err := c.limiter.WaitN(ctx, len(batch)); err != nil { + if err := limiter.WaitN(ctx, len(batch)); err != nil { return fmt.Errorf("err in limiter: %w", err) } if _, err := c.client.PutVectors(ctx, &s3vectors.PutVectorsInput{ @@ -259,7 +260,7 @@ func (c *conn) WriteUpdates(ctx context.Context, r *connect.Request[adiomv1.Writ return connect.NewResponse(&adiomv1.WriteUpdatesResponse{}), nil } -func NewConn(bucketName string, vectorKey string, maxParallelism int, batchSize int, limiter *rate.Limiter) (*conn, error) { +func NewConn(bucketName string, vectorKey string, maxParallelism int, batchSize int, rateLimit int) (*conn, error) { awsConfig, err := config.LoadDefaultConfig(context.Background()) if err != nil { return nil, err @@ -271,7 +272,7 @@ func NewConn(bucketName string, vectorKey string, maxParallelism int, batchSize vectorKey: vectorKey, maxParallelism: maxParallelism, batchSize: batchSize, - limiter: limiter, + limiter: util.NewNamespaceLimiter(nil, rateLimit), }, nil } diff --git a/connectors/util/limiter.go b/connectors/util/limiter.go new file mode 100644 index 00000000..edafdf9c --- /dev/null +++ b/connectors/util/limiter.go @@ -0,0 +1,57 @@ +package util + +import ( + "sync" + + "golang.org/x/time/rate" +) + +type NamespaceLimiter interface { + Get(namespace string) *rate.Limiter +} + +type namespaceLimiter struct { + mut sync.RWMutex + limiters map[string]*rate.Limiter + defaultLimiterFactory func() *rate.Limiter +} + +func (l *namespaceLimiter) Get(namespace string) *rate.Limiter { + l.mut.RLock() + limiter, ok := l.limiters[namespace] + l.mut.RUnlock() + if !ok { + l.mut.Lock() + limiter2, ok2 := l.limiters[namespace] + if ok2 { + l.mut.Unlock() + return limiter2 + } + limiter = l.defaultLimiterFactory() + l.limiters[namespace] = limiter + l.mut.Unlock() + return limiter + } + return limiter +} + +func NewNamespaceLimiter(namespaceToLimit map[string]int, defaultLimit int) *namespaceLimiter { + limiters := map[string]*rate.Limiter{} + for k, v := range namespaceToLimit { + limit := rate.Limit(v) + if v < 0 { + limit = rate.Inf + } + limiters[k] = rate.NewLimiter(limit, v) + } + limit := rate.Limit(defaultLimit) + if defaultLimit < 0 { + limit = rate.Inf + } + return &namespaceLimiter{ + limiters: limiters, + defaultLimiterFactory: func() *rate.Limiter { + return rate.NewLimiter(limit, defaultLimit) + }, + } +} diff --git a/internal/app/options/connectorflags.go b/internal/app/options/connectorflags.go index 130e1909..c917c19e 100644 --- a/internal/app/options/connectorflags.go +++ b/internal/app/options/connectorflags.go @@ -25,7 +25,6 @@ import ( "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" "golang.org/x/net/http2" - "golang.org/x/time/rate" ) var ErrMissingConnector = errors.New("missing or unsupported connector") @@ -380,17 +379,12 @@ func GetRegisteredConnectors() []RegisteredConnector { }, &cli.IntFlag{ Name: "rate-limit", - Value: 500, - Usage: "Max vectors per second across workers on this process.", - }, - &cli.IntFlag{ - Name: "rate-limit-burst", - Value: 700, - Usage: "Max vectors per second across workers on this process (burst). Will be set to at least `rate-limit`.", + Value: 2500, + Usage: "Max vectors per second across workers on this process per namespace (vector index).", }, &cli.IntFlag{ Name: "batch-size", - Value: 200, + Value: 500, Usage: "Max size of each PutVector requests (aws hard limit is 500).", }, }, func(c *cli.Context, args []string, _ AdditionalSettings) (adiomv1connect.ConnectorServiceHandler, error) { @@ -398,10 +392,8 @@ func GetRegisteredConnectors() []RegisteredConnector { vectorKey := c.String("vector-key") maxParallelism := c.Int("max-parallelism") rateLimit := c.Int("rate-limit") - rateLimitBurst := max(c.Int("rate-limit-burst"), rateLimit) batchSize := c.Int("batch-size") - limiter := rate.NewLimiter(rate.Limit(rateLimit), rateLimitBurst) - return s3vector.NewConn(bucket, vectorKey, maxParallelism, batchSize, limiter) + return s3vector.NewConn(bucket, vectorKey, maxParallelism, batchSize, rateLimit) }), }, {