diff --git a/document/field_vector.go b/document/field_vector.go index 4c20013c7..4e142b03b 100644 --- a/document/field_vector.go +++ b/document/field_vector.go @@ -114,6 +114,13 @@ func NewVectorFieldWithIndexingOptions(name string, arrayPositions []uint64, // skip freq/norms for vector field options |= index.SkipFreqNorm + // bivf-flat indexes only supports hamming distance for the primary + // binary index. Similarity here is used for the backing flat index, + // which is set to cosine similarity for recall reasons + if vectorIndexOptimizedFor == index.IndexOptimizedWithBivfFlat { + similarity = index.CosineSimilarity + } + return &VectorField{ name: name, dims: dims, diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index 7c7ff1b98..5acb90b98 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -151,6 +151,12 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, if vectorIndexOptimizedFor == "" { vectorIndexOptimizedFor = index.DefaultIndexOptimization } + // bivf-flat indexes only supports hamming distance for the primary + // binary index. Similarity here is used for the backing flat index, + // which is set to cosine similarity for recall reasons + if vectorIndexOptimizedFor == index.IndexOptimizedWithBivfFlat { + similarity = index.CosineSimilarity + } // normalize raw vector if similarity is cosine // Since the vector can be multi-vector (flattened array of multiple vectors), // we use NormalizeMultiVector to normalize each sub-vector independently. @@ -185,6 +191,12 @@ func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interfac if vectorIndexOptimizedFor == "" { vectorIndexOptimizedFor = index.DefaultIndexOptimization } + // bivf-flat indexes only supports hamming distance for the primary + // binary index. Similarity here is used for the backing flat index, + // which is set to cosine similarity for recall reasons + if vectorIndexOptimizedFor == index.IndexOptimizedWithBivfFlat { + similarity = index.CosineSimilarity + } decodedVector, err := document.DecodeVector(encodedString) if err != nil || len(decodedVector) != fm.Dims { return @@ -289,6 +301,11 @@ func validateVectorFieldAlias(field *FieldMapping, path []string, effectiveOptimizedFor, reflect.ValueOf(index.SupportedVectorIndexOptimizations).MapKeys()) } + // bivf-flat's primary indexes requires vector dimensionality to be a multiple of 8 + if effectiveOptimizedFor == index.IndexOptimizedWithBivfFlat && field.Dims%8 != 0 { + return fmt.Errorf("field: '%s', incompatible vector dimensionality for BIVF-FLAT: %d,"+ + " dimension should be a multiple of 8", effectiveFieldName, field.Dims) + } if fieldAliasCtx != nil { // writing to a nil map is unsafe fieldAliasCtx[effectiveFieldName] = field diff --git a/search/query/knn.go b/search/query/knn.go index ea8780a41..e026979af 100644 --- a/search/query/knn.go +++ b/search/query/knn.go @@ -84,6 +84,12 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader, if q.K <= 0 || len(q.Vector) == 0 { return nil, fmt.Errorf("k must be greater than 0 and vector must be non-empty") } + // bivf-flat indexes only supports hamming distance for the primary + // binary index. Similarity here is used for the backing flat index, + // which is set to cosine similarity for recall reasons + if fieldMapping.VectorIndexOptimizedFor == index.IndexOptimizedWithBivfFlat { + similarityMetric = index.CosineSimilarity + } if similarityMetric == index.CosineSimilarity { // normalize the vector q.Vector = mapping.NormalizeVector(q.Vector) diff --git a/search_knn_test.go b/search_knn_test.go index d053705ca..5df43326f 100644 --- a/search_knn_test.go +++ b/search_knn_test.go @@ -608,6 +608,275 @@ func TestVectorBase64Index(t *testing.T) { } } +// Test to verify that the BIVF-Flat index with vector base64 field mapping returns the +// same results as the non-optimized vector field mapping for L2, Dot Product and Cosine similarities. +// Also test to see no differences in results for any distance metric +func TestVectorBivfFlatIndex(t *testing.T) { + + dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile) + if err != nil { + t.Fatal(err) + } + documents := makeDatasetIntoDocuments(dataset) + + _, searchRequestsCopy, err := readDatasetAndQueries(testInputCompressedFile) + if err != nil { + t.Fatal(err) + } + + for _, doc := range documents { + vec, ok := doc["vector"].([]float32) + if !ok { + t.Fatal("Typecasting vector to float array failed") + } + + buf := new(bytes.Buffer) + for _, v := range vec { + err := binary.Write(buf, binary.LittleEndian, v) + if err != nil { + t.Fatal(err) + } + } + + doc["vectorEncoded"] = base64.StdEncoding.EncodeToString(buf.Bytes()) + } + + for _, sr := range searchRequestsCopy { + for _, kr := range sr.KNN { + kr.Field = "vectorEncoded" + } + } + + contentFM := NewTextFieldMapping() + contentFM.Analyzer = en.AnalyzerName + + vecFML2 := mapping.NewVectorFieldMapping() + vecFML2.Dims = testDatasetDims + vecFML2.Similarity = index.EuclideanDistance + vecFML2.VectorIndexOptimizedFor = index.IndexOptimizedWithBivfFlat + + vecBFML2 := mapping.NewVectorBase64FieldMapping() + vecBFML2.Dims = testDatasetDims + vecBFML2.Similarity = index.EuclideanDistance + vecBFML2.VectorIndexOptimizedFor = index.IndexOptimizedWithBivfFlat + + vecFMDot := mapping.NewVectorFieldMapping() + vecFMDot.Dims = testDatasetDims + vecFMDot.Similarity = index.InnerProduct + vecFMDot.VectorIndexOptimizedFor = index.IndexOptimizedWithBivfFlat + + vecBFMDot := mapping.NewVectorBase64FieldMapping() + vecBFMDot.Dims = testDatasetDims + vecBFMDot.Similarity = index.InnerProduct + vecBFMDot.VectorIndexOptimizedFor = index.IndexOptimizedWithBivfFlat + + vecFMCosine := mapping.NewVectorFieldMapping() + vecFMCosine.Dims = testDatasetDims + vecFMCosine.Similarity = index.CosineSimilarity + + vecBFMCosine := mapping.NewVectorBase64FieldMapping() + vecBFMCosine.Dims = testDatasetDims + vecBFMCosine.Similarity = index.CosineSimilarity + vecBFMCosine.VectorIndexOptimizedFor = index.IndexOptimizedWithBivfFlat + + indexMappingL2 := NewIndexMapping() + indexMappingL2.DefaultMapping.AddFieldMappingsAt("content", contentFM) + indexMappingL2.DefaultMapping.AddFieldMappingsAt("vector", vecFML2) + indexMappingL2.DefaultMapping.AddFieldMappingsAt("vectorEncoded", vecBFML2) + + indexMappingDot := NewIndexMapping() + indexMappingDot.DefaultMapping.AddFieldMappingsAt("content", contentFM) + indexMappingDot.DefaultMapping.AddFieldMappingsAt("vector", vecFMDot) + indexMappingDot.DefaultMapping.AddFieldMappingsAt("vectorEncoded", vecBFMDot) + + indexMappingCosine := NewIndexMapping() + indexMappingCosine.DefaultMapping.AddFieldMappingsAt("content", contentFM) + indexMappingCosine.DefaultMapping.AddFieldMappingsAt("vector", vecFMCosine) + indexMappingCosine.DefaultMapping.AddFieldMappingsAt("vectorEncoded", vecBFMCosine) + + tmpIndexPathL2 := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPathL2) + + tmpIndexPathDot := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPathDot) + + tmpIndexPathCosine := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPathCosine) + + indexL2, err := New(tmpIndexPathL2, indexMappingL2) + if err != nil { + t.Fatal(err) + } + defer func() { + err := indexL2.Close() + if err != nil { + t.Fatal(err) + } + }() + + indexDot, err := New(tmpIndexPathDot, indexMappingDot) + if err != nil { + t.Fatal(err) + } + defer func() { + err := indexDot.Close() + if err != nil { + t.Fatal(err) + } + }() + + indexCosine, err := New(tmpIndexPathCosine, indexMappingCosine) + if err != nil { + t.Fatal(err) + } + defer func() { + err := indexCosine.Close() + if err != nil { + t.Fatal(err) + } + }() + + batchL2 := indexL2.NewBatch() + batchDot := indexDot.NewBatch() + batchCosine := indexCosine.NewBatch() + + for _, doc := range documents { + err = batchL2.Index(doc["id"].(string), doc) + if err != nil { + t.Fatal(err) + } + err = batchDot.Index(doc["id"].(string), doc) + if err != nil { + t.Fatal(err) + } + err = batchCosine.Index(doc["id"].(string), doc) + if err != nil { + t.Fatal(err) + } + } + + err = indexL2.Batch(batchL2) + if err != nil { + t.Fatal(err) + } + + err = indexDot.Batch(batchDot) + if err != nil { + t.Fatal(err) + } + + err = indexCosine.Batch(batchCosine) + if err != nil { + t.Fatal(err) + } + + for i := range searchRequests { + for _, operator := range knnOperators { + normQuery := searchRequests[i] + base64Query := searchRequestsCopy[i] + + normQuery.AddKNNOperator(operator) + base64Query.AddKNNOperator(operator) + + normResultL2, err := indexL2.Search(normQuery) + if err != nil { + t.Fatal(err) + } + base64ResultL2, err := indexL2.Search(base64Query) + if err != nil { + t.Fatal(err) + } + + if normResultL2 != nil && base64ResultL2 != nil { + if len(normResultL2.Hits) == len(base64ResultL2.Hits) { + for j := range normResultL2.Hits { + if normResultL2.Hits[j].ID != base64ResultL2.Hits[j].ID { + t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, normResultL2.Hits[j].ID, base64ResultL2.Hits[j].ID) + } + } + } + } else if (normResultL2 == nil && base64ResultL2 != nil) || + (normResultL2 != nil && base64ResultL2 == nil) { + t.Fatalf("testcase %d failed: expected result %s, got result %s", i, normResultL2, base64ResultL2) + } + + normResultDot, err := indexDot.Search(normQuery) + if err != nil { + t.Fatal(err) + } + base64ResultDot, err := indexDot.Search(base64Query) + if err != nil { + t.Fatal(err) + } + + if normResultDot != nil && base64ResultDot != nil { + if len(normResultDot.Hits) == len(base64ResultDot.Hits) { + for j := range normResultDot.Hits { + if normResultDot.Hits[j].ID != base64ResultDot.Hits[j].ID { + t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, normResultDot.Hits[j].ID, base64ResultDot.Hits[j].ID) + } + } + } + } else if (normResultDot == nil && base64ResultDot != nil) || + (normResultDot != nil && base64ResultDot == nil) { + t.Fatalf("testcase %d failed: expected result %s, got result %s", i, normResultDot, base64ResultDot) + } + + normResultCosine, err := indexCosine.Search(normQuery) + if err != nil { + t.Fatal(err) + } + base64ResultCosine, err := indexCosine.Search(base64Query) + if err != nil { + t.Fatal(err) + } + + if normResultCosine != nil && base64ResultCosine != nil { + if len(normResultCosine.Hits) == len(base64ResultCosine.Hits) { + for j := range normResultCosine.Hits { + if normResultCosine.Hits[j].ID != base64ResultCosine.Hits[j].ID { + t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, normResultCosine.Hits[j].ID, base64ResultCosine.Hits[j].ID) + } + } + } + } else if (normResultCosine == nil && base64ResultCosine != nil) || + (normResultCosine != nil && base64ResultCosine == nil) { + t.Fatalf("testcase %d failed: expected result %s, got result %s", i, normResultCosine, base64ResultCosine) + } + + if normResultCosine != nil && normResultL2 != nil { + if len(normResultCosine.Hits) == len(normResultL2.Hits) { + for j := range normResultCosine.Hits { + if normResultCosine.Hits[j].ID != normResultL2.Hits[j].ID { + if normResultCosine.Hits[j].Score != normResultL2.Hits[j].Score { + t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, normResultCosine.Hits[j].ID, normResultL2.Hits[j].ID) + } + } + } + } + } else if (normResultCosine == nil && normResultL2 != nil) || + (normResultCosine != nil && normResultL2 == nil) { + t.Fatalf("testcase %d failed: expected result %s, got result %s", i, normResultCosine, normResultL2) + } + + if normResultCosine != nil && normResultDot != nil { + if len(normResultCosine.Hits) == len(normResultDot.Hits) { + for j := range normResultCosine.Hits { + if normResultCosine.Hits[j].ID != normResultDot.Hits[j].ID { + if normResultCosine.Hits[j].Score != normResultDot.Hits[j].Score { + t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, normResultCosine.Hits[j].ID, normResultDot.Hits[j].ID) + } + } + } + } + } else if (normResultCosine == nil && normResultDot != nil) || + (normResultCosine != nil && normResultDot == nil) { + t.Fatalf("testcase %d failed: expected result %s, got result %s", i, normResultCosine, normResultDot) + } + } + } +} + type testDocument struct { ID string `json:"id"` Content string `json:"content"`