diff --git a/centroid_index_test.go b/centroid_index_test.go
new file mode 100644
index 000000000..a7334236b
--- /dev/null
+++ b/centroid_index_test.go
@@ -0,0 +1,74 @@
+//go:build vectors
+// +build vectors
+
+package bleve
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"testing"
+
+	"github.com/blevesearch/bleve/v2/analysis/lang/en"
+	"github.com/blevesearch/bleve/v2/mapping"
+	index "github.com/blevesearch/bleve_index_api"
+)
+
+func loadSiftData() ([]map[string]interface{}, error) {
+	fileContent, err := os.ReadFile(os.Getenv("HOME") + "/fts/data/datasets/vec-sift-bucket.json")
+	if err != nil {
+		return nil, err
+	}
+	var documents []map[string]interface{}
+	err = json.Unmarshal(fileContent, &documents)
+	if err != nil {
+		return nil, err
+	}
+	return documents, nil
+}
+
+func TestCentroidIndex(t *testing.T) {
+	_, _, err := readDatasetAndQueries(testInputCompressedFile)
+	if err != nil {
+		t.Fatal(err)
+	}
+	documents, err := loadSiftData()
+	if err != nil {
+		t.Fatal(err)
+	}
+	contentFieldMapping := NewTextFieldMapping()
+	contentFieldMapping.Analyzer = en.AnalyzerName
+
+	vecFieldMappingL2 := mapping.NewVectorFieldMapping()
+	vecFieldMappingL2.Dims = 128
+	vecFieldMappingL2.Similarity = index.EuclideanDistance
+
+	indexMappingL2Norm := NewIndexMapping()
+	indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping)
+	indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingL2)
+
+	idx, err := newIndexUsing(t.TempDir(), indexMappingL2Norm, Config.DefaultIndexType, Config.DefaultKVStore, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		err := idx.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	batch := idx.NewBatch()
+	for _, doc := range documents[:100000] {
+		docId := fmt.Sprintf("%s:%s", index.TrainDataPrefix, doc["id"])
+		err = batch.Index(docId, doc)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	err = idx.Train(batch)
+	if err != nil {
+		t.Fatal(err)
+	}
+}
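For context, each entry loadSiftData returns is assumed to look roughly like the sketch below. Only "id" (used to build the TrainDataPrefix-ed doc ID) and "vector" (which must carry exactly 128 values to satisfy the mapping) are actually exercised by the test; the other field is illustrative.

    // Hedged sketch (not part of the patch) of one vec-sift-bucket.json entry
    // as decoded by loadSiftData.
    var sampleSiftDoc = map[string]interface{}{
        "id":      "1",                  // stringified into "<TrainDataPrefix>:1"
        "content": "accompanying text",  // analyzed by the "en" analyzer
        "vector":  make([]float32, 128), // 128-dim SIFT descriptor
    }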
diff --git a/go.mod b/go.mod
index 5448bac80..c2ec2c4e6 100644
--- a/go.mod
+++ b/go.mod
@@ -19,13 +19,13 @@ require (
 	github.com/blevesearch/stempel v0.2.0
 	github.com/blevesearch/upsidedown_store_api v1.0.2
 	github.com/blevesearch/vellum v1.2.0
-	github.com/blevesearch/zapx/v11 v11.4.2
-	github.com/blevesearch/zapx/v12 v12.4.2
-	github.com/blevesearch/zapx/v13 v13.4.2
-	github.com/blevesearch/zapx/v14 v14.4.2
-	github.com/blevesearch/zapx/v15 v15.4.2
-	github.com/blevesearch/zapx/v16 v16.3.0
-	github.com/blevesearch/zapx/v17 v17.0.1
+	github.com/blevesearch/zapx/v11 v11.4.3
+	github.com/blevesearch/zapx/v12 v12.4.3
+	github.com/blevesearch/zapx/v13 v13.4.3
+	github.com/blevesearch/zapx/v14 v14.4.3
+	github.com/blevesearch/zapx/v15 v15.4.3
+	github.com/blevesearch/zapx/v16 v16.3.1
+	github.com/blevesearch/zapx/v17 v17.0.2-0.20260204210735-148661f2ddf6
 	github.com/couchbase/moss v0.2.0
 	github.com/spf13/cobra v1.10.2
 	go.etcd.io/bbolt v1.4.0
@@ -43,3 +43,25 @@ require (
 	github.com/spf13/pflag v1.0.9 // indirect
 	golang.org/x/sys v0.40.0 // indirect
 )
+
+replace github.com/blevesearch/bleve/v2 => /Users/thejas.orkombu/fts/blevesearch/bleve
+
+replace github.com/blevesearch/zapx/v11 => /Users/thejas.orkombu/fts/blevesearch/zapx11
+
+replace github.com/blevesearch/zapx/v12 => /Users/thejas.orkombu/fts/blevesearch/zapx12
+
+replace github.com/blevesearch/zapx/v13 => /Users/thejas.orkombu/fts/blevesearch/zapx13
+
+replace github.com/blevesearch/zapx/v14 => /Users/thejas.orkombu/fts/blevesearch/zapx14
+
+replace github.com/blevesearch/zapx/v15 => /Users/thejas.orkombu/fts/blevesearch/zapx15
+
+replace github.com/blevesearch/zapx/v16 => /Users/thejas.orkombu/fts/blevesearch/zapx
+
+replace github.com/blevesearch/scorch_segment_api/v2 => /Users/thejas.orkombu/fts/blevesearch/scorch_segment_api
+
+replace github.com/blevesearch/go-faiss => /Users/thejas.orkombu/fts/blevesearch/go-faiss
+
+replace github.com/blevesearch/bleve_index_api => /Users/thejas.orkombu/fts/blevesearch/bleve_index_api
+
+replace github.com/blevesearch/sear => /Users/thejas.orkombu/fts/blevesearch/sear
diff --git a/go.sum b/go.sum
index 8207f8975..b8eba81f8 100644
--- a/go.sum
+++ b/go.sum
@@ -33,20 +33,20 @@ github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMG
 github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ=
 github.com/blevesearch/vellum v1.2.0 h1:xkDiOEsHc2t3Cp0NsNZZ36pvc130sCzcGKOPMzXe+e0=
 github.com/blevesearch/vellum v1.2.0/go.mod h1:uEcfBJz7mAOf0Kvq6qoEKQQkLODBF46SINYNkZNae4k=
-github.com/blevesearch/zapx/v11 v11.4.2 h1:l46SV+b0gFN+Rw3wUI1YdMWdSAVhskYuvxlcgpQFljs=
-github.com/blevesearch/zapx/v11 v11.4.2/go.mod h1:4gdeyy9oGa/lLa6D34R9daXNUvfMPZqUYjPwiLmekwc=
-github.com/blevesearch/zapx/v12 v12.4.2 h1:fzRbhllQmEMUuAQ7zBuMvKRlcPA5ESTgWlDEoB9uQNE=
-github.com/blevesearch/zapx/v12 v12.4.2/go.mod h1:TdFmr7afSz1hFh/SIBCCZvcLfzYvievIH6aEISCte58=
-github.com/blevesearch/zapx/v13 v13.4.2 h1:46PIZCO/ZuKZYgxI8Y7lOJqX3Irkc3N8W82QTK3MVks=
-github.com/blevesearch/zapx/v13 v13.4.2/go.mod h1:knK8z2NdQHlb5ot/uj8wuvOq5PhDGjNYQQy0QDnopZk=
-github.com/blevesearch/zapx/v14 v14.4.2 h1:2SGHakVKd+TrtEqpfeq8X+So5PShQ5nW6GNxT7fWYz0=
-github.com/blevesearch/zapx/v14 v14.4.2/go.mod h1:rz0XNb/OZSMjNorufDGSpFpjoFKhXmppH9Hi7a877D8=
-github.com/blevesearch/zapx/v15 v15.4.2 h1:sWxpDE0QQOTjyxYbAVjt3+0ieu8NCE0fDRaFxEsp31k=
-github.com/blevesearch/zapx/v15 v15.4.2/go.mod h1:1pssev/59FsuWcgSnTa0OeEpOzmhtmr/0/11H0Z8+Nw=
-github.com/blevesearch/zapx/v16 v16.3.0 h1:hF6VlN15E9CB40RMPyqOIhlDw1OOo9RItumhKMQktxw=
-github.com/blevesearch/zapx/v16 v16.3.0/go.mod h1:zCFjv7McXWm1C8rROL+3mUoD5WYe2RKsZP3ufqcYpLY=
-github.com/blevesearch/zapx/v17 v17.0.1 h1:kdojyNDiC4abVvsSwequvqYTBuLEXoG3c0UKyxe1+GM=
-github.com/blevesearch/zapx/v17 v17.0.1/go.mod h1:gvr+JMDB9XvQUkT+CaYJhY7aMlez5EmXbkzOBCVyc7U=
+github.com/blevesearch/zapx/v11 v11.4.3 h1:PTZOO5loKpHC/x/GzmPZNa9cw7GZIQxd5qRjwij9tHY=
+github.com/blevesearch/zapx/v11 v11.4.3/go.mod h1:4gdeyy9oGa/lLa6D34R9daXNUvfMPZqUYjPwiLmekwc=
+github.com/blevesearch/zapx/v12 v12.4.3 h1:eElXvAaAX4m04t//CGBQAtHNPA+Q6A1hHZVrN3LSFYo=
+github.com/blevesearch/zapx/v12 v12.4.3/go.mod h1:TdFmr7afSz1hFh/SIBCCZvcLfzYvievIH6aEISCte58=
+github.com/blevesearch/zapx/v13 v13.4.3 h1:qsdhRhaSpVnqDFlRiH9vG5+KJ+dE7KAW9WyZz/KXAiE=
+github.com/blevesearch/zapx/v13 v13.4.3/go.mod h1:knK8z2NdQHlb5ot/uj8wuvOq5PhDGjNYQQy0QDnopZk=
+github.com/blevesearch/zapx/v14 v14.4.3 h1:GY4Hecx0C6UTmiNC2pKdeA2rOKiLR5/rwpU9WR51dgM=
+github.com/blevesearch/zapx/v14 v14.4.3/go.mod h1:rz0XNb/OZSMjNorufDGSpFpjoFKhXmppH9Hi7a877D8=
+github.com/blevesearch/zapx/v15 v15.4.3 h1:iJiMJOHrz216jyO6lS0m9RTCEkprUnzvqAI2lc/0/CU=
+github.com/blevesearch/zapx/v15 v15.4.3/go.mod h1:1pssev/59FsuWcgSnTa0OeEpOzmhtmr/0/11H0Z8+Nw=
+github.com/blevesearch/zapx/v16 v16.3.1 h1:ERxZUSC9UcuKggCQ6b3y4sTkyL4WnGOWuopzglR874g=
+github.com/blevesearch/zapx/v16 v16.3.1/go.mod h1:zCFjv7McXWm1C8rROL+3mUoD5WYe2RKsZP3ufqcYpLY=
+github.com/blevesearch/zapx/v17 v17.0.2-0.20260204210735-148661f2ddf6 h1:eqJh5al0dcPq6VsY6C+G4kva5BBffzMG+sN/SWg2/Eg=
+github.com/blevesearch/zapx/v17 v17.0.2-0.20260204210735-148661f2ddf6/go.mod h1:gvr+JMDB9XvQUkT+CaYJhY7aMlez5EmXbkzOBCVyc7U=
 github.com/couchbase/ghistogram v0.1.0 h1:b95QcQTCzjTUocDXp/uMgSNQi8oj1tGwnJ4bODWZnps=
 github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k=
 github.com/couchbase/moss v0.2.0 h1:VCYrMzFwEryyhRSeI+/b3tRBSeTpi/8gn5Kf6dxqn+o=
diff --git a/index.go b/index.go
index 2f1ba5fbf..c083787c4 100644
--- a/index.go
+++ b/index.go
@@ -396,3 +396,7 @@ type InsightsIndex interface {
 	// CentroidCardinalities returns the centroids (clusters) from IVF indexes ordered by data density.
 	CentroidCardinalities(field string, limit int, desceding bool) ([]index.CentroidCardinality, error)
 }
+type VectorIndex interface {
+	Index
+	Train(*Batch) error
+}
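VectorIndex is the interface callers type-assert to in order to reach Train. A minimal sketch of the intended call pattern, assuming a vectors-enabled build and a mapping with a "vector" field like the test above:

    // trainIfSupported submits a training batch through the new interface.
    // The field name "vector" and the ID scheme mirror the test; they are
    // not mandated by the interface itself.
    func trainIfSupported(idx bleve.Index, samples [][]float32) error {
        vi, ok := idx.(bleve.VectorIndex)
        if !ok {
            return fmt.Errorf("index does not support training")
        }
        batch := idx.NewBatch()
        for i, v := range samples {
            docID := fmt.Sprintf("%s:%d", index.TrainDataPrefix, i)
            if err := batch.Index(docID, map[string]interface{}{"vector": v}); err != nil {
                return err
            }
        }
        return vi.Train(batch)
    }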
diff --git a/index/scorch/merge.go b/index/scorch/merge.go
index e17288410..bca9bbb81 100644
--- a/index/scorch/merge.go
+++ b/index/scorch/merge.go
@@ -372,8 +372,8 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 	atomic.AddUint64(&s.stats.TotFileMergeZapBeg, 1)
 	prevBytesReadTotal := cumulateBytesRead(segmentsToMerge)
 
-	newDocNums, _, err := s.segPlugin.Merge(segmentsToMerge, docsToDrop, path,
-		cw.cancelCh, s)
+	newDocNums, _, err := s.segPlugin.MergeUsing(segmentsToMerge, docsToDrop, path,
+		cw.cancelCh, s, s.segmentConfig)
 	atomic.AddUint64(&s.stats.TotFileMergeZapEnd, 1)
 	fileMergeZapTime := uint64(time.Since(fileMergeZapStartTime))
@@ -391,7 +391,7 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 		return fmt.Errorf("merging failed: %v", err)
 	}
 
-	seg, err = s.segPlugin.Open(path)
+	seg, err = s.segPlugin.OpenUsing(path, s.segmentConfig)
 	if err != nil {
 		s.unmarkIneligibleForRemoval(filename)
 		atomic.AddUint64(&s.stats.TotFileMergePlanTasksErr, 1)
@@ -540,7 +540,7 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 			// the newly merged segment is already flushed out to disk, just needs
 			// to be opened using mmap.
 			newDocIDs, _, err :=
-				s.segPlugin.Merge(segsBatch, dropsBatch, path, s.closeCh, s)
+				s.segPlugin.MergeUsing(segsBatch, dropsBatch, path, s.closeCh, s, s.segmentConfig)
 			if err != nil {
 				em.Lock()
 				errs = append(errs, err)
@@ -555,7 +555,7 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 			s.markIneligibleForRemoval(filename)
 			newMergedSegmentIDs[id] = newSegmentID
 			newDocIDsSet[id] = newDocIDs
-			newMergedSegments[id], err = s.segPlugin.Open(path)
+			newMergedSegments[id], err = s.segPlugin.OpenUsing(path, s.segmentConfig)
 			if err != nil {
 				em.Lock()
 				errs = append(errs, err)
diff --git a/index/scorch/persister.go b/index/scorch/persister.go
index d0c013a1d..3df4ac2e6 100644
--- a/index/scorch/persister.go
+++ b/index/scorch/persister.go
@@ -425,7 +425,6 @@ func (s *Scorch) persistSnapshotMaybeMerge(snapshot *IndexSnapshot, po *persiste
 	var totSize int
 	var numSegsToFlushOut int
 	var totDocs uint64
-
 	// legacy behaviour of merge + flush of all in-memory segments in one-shot
 	if legacyFlushBehaviour(po.MaxSizeInMemoryMergePerWorker, po.NumPersisterWorkers) {
 		val := &flushable{
@@ -804,7 +803,7 @@ func (s *Scorch) persistSnapshotDirect(snapshot *IndexSnapshot, exclude map[uint
 		}
 	}()
 	for segmentID, path := range newSegmentPaths {
-		newSegments[segmentID], err = s.segPlugin.Open(path)
+		newSegments[segmentID], err = s.segPlugin.OpenUsing(path, s.segmentConfig)
 		if err != nil {
 			return fmt.Errorf("error opening new segment at %s, %v", path, err)
 		}
@@ -853,6 +852,10 @@ func zapFileName(epoch uint64) string {
 	return fmt.Sprintf("%012x.zap", epoch)
 }
 
+func (s *Scorch) loadTrainedData(bucket *bolt.Bucket) error {
+	return s.trainer.loadTrainedData(bucket)
+}
+
 // bolt snapshot code
 
 func (s *Scorch) loadFromBolt() error {
@@ -904,6 +907,12 @@ func (s *Scorch) loadFromBolt() error {
 			foundRoot = true
 		}
+
+		trainerBucket := snapshots.Bucket(util.BoltTrainerKey)
+		err := s.trainer.loadTrainedData(trainerBucket)
+		if err != nil {
+			return err
+		}
 		return nil
 	})
 	if err != nil {
@@ -1016,7 +1025,7 @@ func (s *Scorch) loadSegment(segmentBucket *bolt.Bucket) (*SegmentSnapshot, erro
 		return nil, fmt.Errorf("segment path missing")
 	}
 	segmentPath := s.path + string(os.PathSeparator) + string(pathBytes)
-	seg, err := s.segPlugin.Open(segmentPath)
+	seg, err := s.segPlugin.OpenUsing(segmentPath, s.segmentConfig)
 	if err != nil {
 		return nil, fmt.Errorf("error opening bolt segment: %v", err)
 	}
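persister.go now threads a trainer bucket through the snapshots bucket. For orientation, a hedged sketch of reading that checkpoint back, mirroring what loadFromBolt does; bucket nesting and key bytes come from util/keys.go and the Put in trainLoop, and the helper itself is not part of the patch:

    func readTrainerCheckpoint(db *bolt.DB) (string, error) {
        var path string
        err := db.View(func(tx *bolt.Tx) error {
            snapshots := tx.Bucket(util.BoltSnapshotsBucket) // 's'
            if snapshots == nil {
                return nil // nothing persisted yet
            }
            trainer := snapshots.Bucket(util.BoltTrainerKey) // 's' -> 't'
            if trainer == nil {
                return nil // never trained
            }
            path = string(trainer.Get(util.BoltPathKey)) // 'p' -> centroid index file name
            return nil
        })
        return path, err
    }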
diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go
index 329de598e..afe2878f0 100644
--- a/index/scorch/scorch.go
+++ b/index/scorch/scorch.go
@@ -45,6 +45,7 @@ type Scorch struct {
 	readOnly      bool
 	version       uint8
 	config        map[string]interface{}
+	segmentConfig map[string]interface{}
 	analysisQueue *index.AnalysisQueue
 	path          string
 
@@ -78,6 +79,8 @@ type Scorch struct {
 	rootBolt   *bolt.DB
 	asyncTasks sync.WaitGroup
 
+	trainer trainer
+
 	onEvent      func(event Event) bool
 	onAsyncError func(err error, path string)
 
@@ -88,6 +91,29 @@ type Scorch struct {
 	spatialPlugin index.SpatialAnalyzerPlugin
 }
 
+// trainer is the interface for training an index that has the concept
+// of "learning". Naturally, a vector index is one such thing that would
+// implement this interface. There can be multiple implementations of the
+// training itself, even for the same index type.
+//
+// this component is not supposed to interact with the other master routines
+// of scorch and is used only for training the index before the actual data
+// ingestion starts. The routine should also be released once the
+// training is marked as complete - which can be done using the BoltTrainCompleteKey
+// key and a bool value. However, the struct is still maintained for the pointer to
+// the instance so that we can use it in the later stages of the index lifecycle.
+type trainer interface {
+	// ephemeral
+	trainLoop()
+	// for the training state and the ingestion of the samples
+	train(batch *index.Batch) error
+
+	// to load the metadata from the bolt under the BoltTrainerKey
+	loadTrainedData(*bolt.Bucket) error
+	// to fetch the internal data from the component
+	getInternal(key []byte) ([]byte, error)
+}
+
 type ScorchErrorType string
 
 func (t ScorchErrorType) Error() string {
@@ -154,6 +180,7 @@ func NewScorch(storeName string,
 		forceMergeRequestCh: make(chan *mergerCtrl, 1),
 		segPlugin:           defaultSegmentPlugin,
 		copyScheduled:       map[string]int{},
+		segmentConfig:       make(map[string]interface{}),
 	}
 
 	forcedSegmentType, forcedSegmentVersion, err := configForceSegmentTypeVersion(config)
@@ -168,6 +195,11 @@ func NewScorch(storeName string,
 		}
 	}
 
+	segConfig, ok := config["segmentConfig"].(map[string]interface{})
+	if ok {
+		rv.segmentConfig = segConfig
+	}
+
 	typ, ok := config["spatialPlugin"].(string)
 	if ok {
 		if err := rv.loadSpatialAnalyzerPlugin(typ); err != nil {
@@ -205,6 +237,8 @@ func NewScorch(storeName string,
 		return nil, err
 	}
 
+	rv.trainer = initTrainer(rv)
+
 	return rv, nil
 }
 
@@ -259,6 +293,9 @@ func (s *Scorch) Open() error {
 	s.asyncTasks.Add(1)
 	go s.introducerLoop()
 
+	s.asyncTasks.Add(1)
+	go s.trainer.trainLoop()
+
 	if !s.readOnly && s.path != "" {
 		s.asyncTasks.Add(1)
 		go s.persisterLoop()
@@ -497,7 +534,7 @@ func (s *Scorch) Batch(batch *index.Batch) (err error) {
 	stats := newFieldStats()
 	if len(analysisResults) > 0 {
-		newSegment, bufBytes, err = s.segPlugin.New(analysisResults)
+		newSegment, bufBytes, err = s.segPlugin.NewUsing(analysisResults, s.segmentConfig)
 		if err != nil {
 			return err
 		}
 	}
@@ -532,6 +569,21 @@ func (s *Scorch) Batch(batch *index.Batch) (err error) {
 	return err
 }
 
+func (s *Scorch) getInternal(key []byte) ([]byte, error) {
+	s.rootLock.RLock()
+	defer s.rootLock.RUnlock()
+
+	switch string(key) {
+	case string(util.BoltTrainCompleteKey):
+		return s.trainer.getInternal(key)
+	}
+	return nil, nil
+}
+
+func (s *Scorch) Train(batch *index.Batch) error {
+	return s.trainer.train(batch)
+}
+
 func (s *Scorch) prepareSegment(newSegment segment.Segment, ids []string,
 	internalOps map[string][]byte, persistedCallback index.BatchCallback, stats *fieldStats,
 ) error {
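NewScorch lifts "segmentConfig" straight out of the index config. A sketch of supplying one at open time, under the assumption that kvconfig entries reach the scorch config map the way keys like forceSegmentType do; the nested option key is purely illustrative, since the patch itself only ever sets CentroidIndexCallback and TrainingKey internally:

    func openWithSegmentConfig(path string, im mapping.IndexMapping) (bleve.Index, error) {
        cfg := map[string]interface{}{
            "segmentConfig": map[string]interface{}{
                "exampleOption": true, // illustrative; interpreted by the segment plugin, not scorch
            },
        }
        return bleve.NewUsing(path, im, scorch.Name, scorch.Name, cfg)
    }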
diff --git a/index/scorch/segment_plugin.go b/index/scorch/segment_plugin.go
index c44f9cf7b..16be8e440 100644
--- a/index/scorch/segment_plugin.go
+++ b/index/scorch/segment_plugin.go
@@ -46,10 +46,14 @@ type SegmentPlugin interface {
 	// New takes a set of Documents and turns them into a new Segment
 	New(results []index.Document) (segment.Segment, uint64, error)
 
+	NewUsing(results []index.Document, config map[string]interface{}) (segment.Segment, uint64, error)
+
 	// Open attempts to open the file at the specified path and
 	// return the corresponding Segment
 	Open(path string) (segment.Segment, error)
 
+	OpenUsing(path string, config map[string]interface{}) (segment.Segment, error)
+
 	// Merge takes a set of Segments, and creates a new segment on disk at
 	// the specified path.
 	// Drops is a set of bitmaps (one for each segment) indicating which
@@ -67,6 +71,10 @@ type SegmentPlugin interface {
 	Merge(segments []segment.Segment, drops []*roaring.Bitmap, path string,
 		closeCh chan struct{}, s segment.StatsReporter) (
 		[][]uint64, uint64, error)
+
+	MergeUsing(segments []segment.Segment, drops []*roaring.Bitmap, path string,
+		closeCh chan struct{}, s segment.StatsReporter, config map[string]interface{}) (
+		[][]uint64, uint64, error)
 }
 
 var supportedSegmentPlugins map[string]map[uint32]SegmentPlugin
diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go
index 3422d9a14..688f9d903 100644
--- a/index/scorch/snapshot_index.go
+++ b/index/scorch/snapshot_index.go
@@ -466,6 +466,10 @@ func (is *IndexSnapshot) Fields() ([]string, error) {
 }
 
 func (is *IndexSnapshot) GetInternal(key []byte) ([]byte, error) {
+	_, ok := is.internal[string(key)]
+	if !ok {
+		return is.parent.getInternal(key)
+	}
 	return is.internal[string(key)], nil
 }
 
diff --git a/index/scorch/train_noop.go b/index/scorch/train_noop.go
new file mode 100644
index 000000000..d82b342c6
--- /dev/null
+++ b/index/scorch/train_noop.go
@@ -0,0 +1,47 @@
+// Copyright (c) 2026 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !vectors
+// +build !vectors
+
+package scorch
+
+import (
+	"fmt"
+
+	index "github.com/blevesearch/bleve_index_api"
+	bolt "go.etcd.io/bbolt"
+)
+
+func initTrainer(s *Scorch) *noopTrainer {
+	return &noopTrainer{}
+}
+
+type noopTrainer struct {
+}
+
+func (t *noopTrainer) trainLoop() {}
+
+func (t *noopTrainer) train(batch *index.Batch) error {
+	return fmt.Errorf("training is not supported with this build")
+}
+
+func (t *noopTrainer) loadTrainedData(bucket *bolt.Bucket) error {
+	// noop
+	return nil
+}
+
+func (t *noopTrainer) getInternal(key []byte) ([]byte, error) {
+	return nil, nil
+}
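With the GetInternal fallthrough in place, training completion can be checked from above the snapshot. A hedged sketch: on vectors builds the trainer answers "true"/"false", while the noop trainer answers nil.

    func trainingComplete(idx bleve.Index) (bool, error) {
        ii, err := idx.Advanced() // unwrap to the bleve_index_api Index
        if err != nil {
            return false, err
        }
        reader, err := ii.Reader()
        if err != nil {
            return false, err
        }
        defer reader.Close()
        v, err := reader.GetInternal(util.BoltTrainCompleteKey)
        if err != nil {
            return false, err
        }
        return string(v) == "true", nil
    }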
diff --git a/index/scorch/train_vector.go b/index/scorch/train_vector.go
new file mode 100644
index 000000000..74cc6b4ed
--- /dev/null
+++ b/index/scorch/train_vector.go
@@ -0,0 +1,268 @@
+// Copyright (c) 2026 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vectors
+// +build vectors
+
+package scorch
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+
+	"github.com/RoaringBitmap/roaring/v2"
+	"github.com/blevesearch/bleve/v2/util"
+	index "github.com/blevesearch/bleve_index_api"
+	"github.com/blevesearch/go-faiss"
+	segment "github.com/blevesearch/scorch_segment_api/v2"
+	bolt "go.etcd.io/bbolt"
+)
+
+type trainRequest struct {
+	sample   segment.Segment
+	vecCount int
+	ackCh    chan error
+}
+
+func initTrainer(s *Scorch) *vectorTrainer {
+	return &vectorTrainer{
+		parent:  s,
+		trainCh: make(chan *trainRequest),
+	}
+}
+
+type vectorTrainer struct {
+	parent *Scorch
+
+	m sync.Mutex
+	// not a searchable segment in the sense that it won't return
+	// the data vectors. can return centroid vectors
+	centroidIndex *SegmentSnapshot
+	trainCh       chan *trainRequest
+}
+
+func moveFile(sourcePath, destPath string) error {
+	// rename is supposed to be atomic on the same filesystem
+	err := os.Rename(sourcePath, destPath)
+	if err != nil {
+		return fmt.Errorf("error renaming file: %v", err)
+	}
+	return nil
+}
+
+// this is not a routine that runs throughout the lifetime of the index. Its
+// purpose is only to train the vector index before the data ingestion starts.
+func (t *vectorTrainer) trainLoop() {
+	defer func() {
+		t.parent.asyncTasks.Done()
+	}()
+	// initialize stuff
+	t.parent.segmentConfig[index.CentroidIndexCallback] = t.getCentroidIndex
+	path := filepath.Join(t.parent.path, index.CentroidIndexFileName)
+	for {
+		select {
+		case <-t.parent.closeCh:
+			return
+		case trainReq := <-t.trainCh:
+			sampleSeg := trainReq.sample
+			if t.centroidIndex == nil {
+				switch seg := sampleSeg.(type) {
+				case segment.UnpersistedSegment:
+					err := persistToDirectory(seg, nil, path)
+					if err != nil {
+						trainReq.ackCh <- fmt.Errorf("error persisting segment: %v", err)
+						close(trainReq.ackCh)
+						return
+					}
+				default:
+				}
+			} else {
+				// merge the new segment with the existing one, no need to persist?
+				// persist in a tmp file and then rename - is that a fair strategy?
+				t.parent.segmentConfig[index.TrainingKey] = true
+				_, _, err := t.parent.segPlugin.MergeUsing([]segment.Segment{t.centroidIndex.segment, sampleSeg},
+					[]*roaring.Bitmap{nil, nil}, path+".tmp", t.parent.closeCh, nil, t.parent.segmentConfig)
+				if err != nil {
+					trainReq.ackCh <- fmt.Errorf("error merging centroid index: %v", err)
+					close(trainReq.ackCh)
+				}
+				// reset the training flag once completed
+				t.parent.segmentConfig[index.TrainingKey] = false
+
+				// close the existing centroid segment - it's supposed to be gc'd at this point
+				t.centroidIndex.segment.Close()
+				err = moveFile(path+".tmp", path)
+				if err != nil {
+					trainReq.ackCh <- fmt.Errorf("error renaming centroid index: %v", err)
+					close(trainReq.ackCh)
+				}
+			}
+			// a bolt transaction is necessary for the failover-recovery scenario and also serves as a checkpoint
+			// where we can be sure that the centroid index is available for the indexing operations downstream
+			//
+			// note: when the scale increases massively, especially with real-world dimensions of 1536+, this API
+			// will have to be refactored to persist in a more resource-efficient way. so having this bolt-related
+			// code will help in tracking the progress a lot better and avoid any redundant data streaming operations.
+			//
+			// todo: rethink the frequency of bolt writes
+			tx, err := t.parent.rootBolt.Begin(true)
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error starting bolt transaction: %v", err)
+				close(trainReq.ackCh)
+				return
+			}
+			defer func() {
+				if err != nil {
+					_ = tx.Rollback()
+				}
+			}()
+
+			snapshotsBucket, err := tx.CreateBucketIfNotExists(util.BoltSnapshotsBucket)
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error creating snapshots bucket: %v", err)
+				close(trainReq.ackCh)
+				return
+			}
+
+			trainerBucket, err := snapshotsBucket.CreateBucketIfNotExists(util.BoltTrainerKey)
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error creating centroid bucket: %v", err)
+				close(trainReq.ackCh)
+				return
+			}
+
+			err = trainerBucket.Put(util.BoltPathKey, []byte(index.CentroidIndexFileName))
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error updating centroid bucket: %v", err)
+				close(trainReq.ackCh)
+				return
+			}
+
+			err = tx.Commit()
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error committing bolt transaction: %v", err)
+				close(trainReq.ackCh)
+				return
+			}
+
+			err = t.parent.rootBolt.Sync()
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error on bolt sync: %v", err)
+				close(trainReq.ackCh)
+				return
+			}
+
+			// update the centroid index pointer
+			centroidIndex, err := t.parent.segPlugin.OpenUsing(path, t.parent.segmentConfig)
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error opening centroid index: %v", err)
+				close(trainReq.ackCh)
+				return
+			}
+			t.m.Lock()
+			t.centroidIndex = &SegmentSnapshot{
+				segment: centroidIndex,
+			}
+			t.m.Unlock()
+			close(trainReq.ackCh)
+		}
+	}
+}
+
+// loads the metadata specific to the centroid index from boltdb
+func (t *vectorTrainer) loadTrainedData(bucket *bolt.Bucket) error {
+	if bucket == nil {
+		return nil
+	}
+	segmentSnapshot, err := t.parent.loadSegment(bucket)
+	if err != nil {
+		return err
+	}
+	t.m.Lock()
+	defer t.m.Unlock()
+	t.centroidIndex = segmentSnapshot
+	return nil
+}
+
+func (t *vectorTrainer) train(batch *index.Batch) error {
+	// regulate the Train function
+	t.parent.FireIndexEvent()
+
+	var trainData []index.Document
+	for key, doc := range batch.IndexOps {
+		if doc != nil {
+			// insert _id field
+			// no need to track updates/deletes over here since
+			// the API is singleton
+			doc.AddIDField()
+		}
+		if strings.HasPrefix(key, index.TrainDataPrefix) {
+			trainData = append(trainData, doc)
+		}
+	}
+
+	// just builds a new vector index out of the train data provided.
+	// this is not necessarily the final train data, since this is submitted
+	// as a request to the trainer component to be merged. once the training
+	// is complete, the template will be used for other operations down the line,
+	// like merge and search.
+	//
+	// note: this might index text data too, how to handle this? s.segmentConfig?
+	// todo: updates/deletes -> data drift detection
+	seg, _, err := t.parent.segPlugin.NewUsing(trainData, t.parent.segmentConfig)
+	if err != nil {
+		return err
+	}
+
+	trainReq := &trainRequest{
+		sample:   seg,
+		vecCount: len(trainData), // todo: multivector support
+		ackCh:    make(chan error),
+	}
+
+	t.trainCh <- trainReq
+	err = <-trainReq.ackCh
+	if err != nil {
+		return fmt.Errorf("train_vector: train() err'd out with: %w", err)
+	}
+
+	return err
+}
+
+func (t *vectorTrainer) getInternal(key []byte) ([]byte, error) {
+	// todo: return the total number of vectors that have been processed so far in training;
+	// in cbft, use that as a checkpoint to resume training for n-x samples.
+	switch string(key) {
+	case string(util.BoltTrainCompleteKey):
+		return []byte(fmt.Sprintf("%t", t.centroidIndex != nil)), nil
+	}
+	return nil, nil
+}
+
+func (t *vectorTrainer) getCentroidIndex(field string) (*faiss.IndexImpl, error) {
+	// return the coarse quantizer of the centroid index belonging to the field
+	centroidIndexSegment, ok := t.centroidIndex.segment.(segment.CentroidIndexSegment)
+	if !ok {
+		return nil, fmt.Errorf("segment is not a centroid index segment")
+	}
+
+	coarseQuantizer, err := centroidIndexSegment.GetCoarseQuantizer(field)
+	if err != nil {
+		return nil, err
+	}
+	return coarseQuantizer, nil
+}
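getCentroidIndex is handed to segment plugins through segmentConfig under CentroidIndexCallback. A hedged sketch of the consumer side; the function type matches vectorTrainer.getCentroidIndex, while the lookup helper and field name are illustrative:

    func coarseQuantizerFor(config map[string]interface{}, field string) (*faiss.IndexImpl, error) {
        cb, ok := config[index.CentroidIndexCallback].(func(string) (*faiss.IndexImpl, error))
        if !ok {
            return nil, fmt.Errorf("no centroid index callback in segment config")
        }
        // used by the segment builder to assign incoming vectors to trained centroids
        return cb(field)
    }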
diff --git a/index_alias_impl.go b/index_alias_impl.go
index 8212c74b9..2839752e2 100644
--- a/index_alias_impl.go
+++ b/index_alias_impl.go
@@ -103,6 +103,24 @@ func (i *indexAliasImpl) IndexSynonym(id string, collection string, definition *
 	return ErrorSynonymSearchNotSupported
 }
 
+func (i *indexAliasImpl) Train(batch *Batch) error {
+	i.mutex.RLock()
+	defer i.mutex.RUnlock()
+	if !i.open {
+		return ErrorIndexClosed
+	}
+
+	err := i.isAliasToSingleIndex()
+	if err != nil {
+		return err
+	}
+
+	if vi, ok := i.indexes[0].(VectorIndex); ok {
+		return vi.Train(batch)
+	}
+	return fmt.Errorf("not a vector index")
+}
+
 func (i *indexAliasImpl) Delete(id string) error {
 	i.mutex.RLock()
 	defer i.mutex.RUnlock()
diff --git a/index_impl.go b/index_impl.go
index 586dacb3b..0d7e1dd4d 100644
--- a/index_impl.go
+++ b/index_impl.go
@@ -369,6 +369,20 @@ func (i *indexImpl) IndexSynonym(id string, collection string, definition *Synon
 	return err
 }
 
+func (i *indexImpl) Train(batch *Batch) error {
+	i.mutex.RLock()
+	defer i.mutex.RUnlock()
+
+	if !i.open {
+		return ErrorIndexClosed
+	}
+
+	if vi, ok := i.i.(index.VectorIndex); ok {
+		return vi.Train(batch.internal)
+	}
+	return fmt.Errorf("not a vector index")
+}
+
 // IndexAdvanced takes a document.Document object
 // skips the mapping and indexes it.
 func (i *indexImpl) IndexAdvanced(doc *document.Document) (err error) {
@@ -1442,7 +1456,7 @@ func (i *indexImpl) CopyTo(d index.Directory) (err error) {
 	err = copyReader.CopyTo(d)
 	if err != nil {
-		return fmt.Errorf("error copying index metadata: %v", err)
+		return fmt.Errorf("error copying index data: %v", err)
 	}
 
 	// copy the metadata
diff --git a/util/keys.go b/util/keys.go
index b71a7f48b..ce8965da2 100644
--- a/util/keys.go
+++ b/util/keys.go
@@ -17,6 +17,8 @@ package util
 var (
 	// Bolt keys
 	BoltSnapshotsBucket = []byte{'s'}
+	BoltTrainerKey      = []byte{'t'}
+	BoltTrainCompleteKey = []byte{'c'}
 	BoltPathKey         = []byte{'p'}
 	BoltDeletedKey      = []byte{'d'}
 	BoltInternalKey     = []byte{'i'}
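Taken together, the pieces above imply a train-before-ingest lifecycle. A hedged end-to-end sketch, with the batches assumed to be built as in TestCentroidIndex (trainBatch holding TrainDataPrefix-ed docs, dataBatch holding regular ones):

    // Sketch of the intended lifecycle: train on a prefixed sample batch first,
    // then begin regular ingestion once the centroid index is checkpointed.
    func trainThenIngest(idx bleve.Index, trainBatch, dataBatch *bleve.Batch) error {
        if vi, ok := idx.(bleve.VectorIndex); ok {
            if err := vi.Train(trainBatch); err != nil {
                return err
            }
        }
        // training is checkpointed in bolt under BoltTrainerKey and the centroid
        // index file is in place; normal ingestion can start
        return idx.Batch(dataBatch)
    }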