Merged
15 changes: 8 additions & 7 deletions connectors/s3vector/conn.go
@@ -19,7 +19,6 @@ import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"golang.org/x/sync/errgroup"
"golang.org/x/time/rate"
)

type conn struct {
@@ -28,7 +27,7 @@ type conn struct {
vectorKey string
maxParallelism int
batchSize int
limiter *rate.Limiter
limiter util.NamespaceLimiter
}

// GeneratePlan implements adiomv1connect.ConnectorServiceHandler.
@@ -87,9 +86,10 @@ func (c *conn) WriteData(ctx context.Context, r *connect.Request[adiomv1.WriteDa

eg, ctx := errgroup.WithContext(ctx)
eg.SetLimit(c.maxParallelism)
limiter := c.limiter.Get(r.Msg.GetNamespace())
for batch := range slices.Chunk(vectors, c.batchSize) {
eg.Go(func() error {
if err := c.limiter.WaitN(ctx, len(batch)); err != nil {
if err := limiter.WaitN(ctx, len(batch)); err != nil {
return fmt.Errorf("err in limiter: %w", err)
}
if _, err := c.client.PutVectors(ctx, &s3vectors.PutVectorsInput{
@@ -220,11 +220,12 @@ func (c *conn) WriteUpdates(ctx context.Context, r *connect.Request[adiomv1.Writ
}
}
if len(updates) > 0 {
limiter := c.limiter.Get(r.Msg.GetNamespace())
eg, ctx := errgroup.WithContext(ctx)
eg.SetLimit(c.maxParallelism)
for batch := range slices.Chunk(toDelete, c.batchSize) {
eg.Go(func() error {
if err := c.limiter.WaitN(ctx, len(batch)); err != nil {
if err := limiter.WaitN(ctx, len(batch)); err != nil {
return fmt.Errorf("err in limiter: %w", err)
}
if _, err := c.client.DeleteVectors(ctx, &s3vectors.DeleteVectorsInput{
@@ -239,7 +240,7 @@ func (c *conn) WriteUpdates(ctx context.Context, r *connect.Request[adiomv1.Writ
}
for batch := range slices.Chunk(vectors, c.batchSize) {
eg.Go(func() error {
if err := c.limiter.WaitN(ctx, len(batch)); err != nil {
if err := limiter.WaitN(ctx, len(batch)); err != nil {
return fmt.Errorf("err in limiter: %w", err)
}
if _, err := c.client.PutVectors(ctx, &s3vectors.PutVectorsInput{
Expand All @@ -259,7 +260,7 @@ func (c *conn) WriteUpdates(ctx context.Context, r *connect.Request[adiomv1.Writ
return connect.NewResponse(&adiomv1.WriteUpdatesResponse{}), nil
}

func NewConn(bucketName string, vectorKey string, maxParallelism int, batchSize int, limiter *rate.Limiter) (*conn, error) {
func NewConn(bucketName string, vectorKey string, maxParallelism int, batchSize int, rateLimit int) (*conn, error) {
awsConfig, err := config.LoadDefaultConfig(context.Background())
if err != nil {
return nil, err
@@ -271,7 +272,7 @@ func NewConn(bucketName string, vectorKey string, maxParallelism int, batchSize
vectorKey: vectorKey,
maxParallelism: maxParallelism,
batchSize: batchSize,
limiter: limiter,
limiter: util.NewNamespaceLimiter(nil, rateLimit),
}, nil
}

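For readers skimming the diff, here is a minimal sketch (not the connector's exact code) of the write pattern these hunks introduce: the per-namespace limiter is resolved once per request and then shared by all batch workers, so each vector index gets its own rate budget. The module path and the putBatch callback are assumptions for illustration.

package example

import (
    "context"
    "fmt"
    "slices"

    "golang.org/x/sync/errgroup"

    "github.com/adiom-data/dsync/connectors/util" // assumed module path
)

// writeBatches resolves the limiter for one namespace and fans batches out to
// workers; putBatch stands in for the PutVectors / DeleteVectors call.
func writeBatches[T any](ctx context.Context, nl util.NamespaceLimiter, namespace string,
    items []T, batchSize, maxParallelism int,
    putBatch func(context.Context, []T) error) error {
    eg, ctx := errgroup.WithContext(ctx)
    eg.SetLimit(maxParallelism)
    limiter := nl.Get(namespace) // one *rate.Limiter per vector index
    for batch := range slices.Chunk(items, batchSize) {
        eg.Go(func() error {
            // Block until this namespace's limiter grants len(batch) tokens.
            if err := limiter.WaitN(ctx, len(batch)); err != nil {
                return fmt.Errorf("err in limiter: %w", err)
            }
            return putBatch(ctx, batch)
        })
    }
    return eg.Wait()
}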
57 changes: 57 additions & 0 deletions connectors/util/limiter.go
@@ -0,0 +1,57 @@
package util

import (
    "sync"

    "golang.org/x/time/rate"
)

type NamespaceLimiter interface {
    Get(namespace string) *rate.Limiter
}

type namespaceLimiter struct {
    mut                   sync.RWMutex
    limiters              map[string]*rate.Limiter
    defaultLimiterFactory func() *rate.Limiter
}

func (l *namespaceLimiter) Get(namespace string) *rate.Limiter {
    l.mut.RLock()
    limiter, ok := l.limiters[namespace]
    l.mut.RUnlock()
    if !ok {
        l.mut.Lock()
        limiter2, ok2 := l.limiters[namespace]
        if ok2 {
            l.mut.Unlock()
            return limiter2
        }
        limiter = l.defaultLimiterFactory()
        l.limiters[namespace] = limiter
        l.mut.Unlock()
        return limiter
    }
    return limiter
}

func NewNamespaceLimiter(namespaceToLimit map[string]int, defaultLimit int) *namespaceLimiter {
    limiters := map[string]*rate.Limiter{}
    for k, v := range namespaceToLimit {
        limit := rate.Limit(v)
        if v < 0 {
            limit = rate.Inf
        }
        limiters[k] = rate.NewLimiter(limit, v)
    }
    limit := rate.Limit(defaultLimit)
    if defaultLimit < 0 {
        limit = rate.Inf
    }
    return &namespaceLimiter{
        limiters: limiters,
        defaultLimiterFactory: func() *rate.Limiter {
            return rate.NewLimiter(limit, defaultLimit)
        },
    }
}
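A small usage sketch of the new limiter (index names and rates are illustrative, and the import path is assumed): repeated Get calls for a namespace return the same *rate.Limiter, and namespaces without an explicit entry fall back to the default factory on first use.

package main

import (
    "context"
    "fmt"

    "github.com/adiom-data/dsync/connectors/util" // assumed module path
)

func main() {
    // "hot-index" is capped at 100 vectors/s; any other namespace lazily
    // gets its own limiter at the default 2500 vectors/s.
    nl := util.NewNamespaceLimiter(map[string]int{"hot-index": 100}, 2500)

    a := nl.Get("hot-index")
    b := nl.Get("hot-index")
    fmt.Println(a == b) // true: the limiter is created once and reused

    // Reserve 500 tokens before sending a 500-vector batch.
    if err := nl.Get("other-index").WaitN(context.Background(), 500); err != nil {
        fmt.Println("limiter error:", err)
    }
}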
16 changes: 4 additions & 12 deletions internal/app/options/connectorflags.go
@@ -25,7 +25,6 @@ import (
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"golang.org/x/net/http2"
"golang.org/x/time/rate"
)

var ErrMissingConnector = errors.New("missing or unsupported connector")
@@ -380,28 +379,21 @@ func GetRegisteredConnectors() []RegisteredConnector {
},
&cli.IntFlag{
Name: "rate-limit",
Value: 500,
Usage: "Max vectors per second across workers on this process.",
},
&cli.IntFlag{
Name: "rate-limit-burst",
Value: 700,
Usage: "Max vectors per second across workers on this process (burst). Will be set to at least `rate-limit`.",
Value: 2500,
Usage: "Max vectors per second across workers on this process per namespace (vector index).",
},
&cli.IntFlag{
Name: "batch-size",
Value: 200,
Value: 500,
Usage: "Max size of each PutVector requests (aws hard limit is 500).",
},
}, func(c *cli.Context, args []string, _ AdditionalSettings) (adiomv1connect.ConnectorServiceHandler, error) {
bucket := c.String("bucket")
vectorKey := c.String("vector-key")
maxParallelism := c.Int("max-parallelism")
rateLimit := c.Int("rate-limit")
rateLimitBurst := max(c.Int("rate-limit-burst"), rateLimit)
batchSize := c.Int("batch-size")
limiter := rate.NewLimiter(rate.Limit(rateLimit), rateLimitBurst)
return s3vector.NewConn(bucket, vectorKey, maxParallelism, batchSize, limiter)
return s3vector.NewConn(bucket, vectorKey, maxParallelism, batchSize, rateLimit)
}),
},
{
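With the default flag values above, the registration boils down to roughly the following (bucket, vector key, and parallelism are placeholders; the import path is assumed). Because NewNamespaceLimiter uses the same value for rate and burst, the dropped rate-limit-burst flag is effectively replaced by a burst equal to rate-limit.

package main

import (
    "log"

    "github.com/adiom-data/dsync/connectors/s3vector" // assumed module path
)

func main() {
    // bucket="my-bucket", vector-key="embedding", max-parallelism=4,
    // batch-size=500, rate-limit=2500 (the new defaults, except parallelism).
    handler, err := s3vector.NewConn("my-bucket", "embedding", 4, 500, 2500)
    if err != nil {
        log.Fatal(err)
    }
    // NewConn builds util.NewNamespaceLimiter(nil, 2500) internally, so each
    // vector index is limited to 2500 vectors/s with a burst of 2500.
    _ = handler
}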