From 6ed70262a6cbdcea2f129a75d6ddfba41977b944 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Sun, 29 Dec 2024 10:51:48 -0800 Subject: [PATCH 1/2] add vector index support --- api_mutate_helper.go | 37 +++++++++++-------- api_test.go | 86 +++++++++++++++++++++++++++++++++++++++++++- api_types.go | 11 ++++++ 3 files changed, 119 insertions(+), 15 deletions(-) diff --git a/api_mutate_helper.go b/api_mutate_helper.go index e997258..9f7fb62 100644 --- a/api_mutate_helper.go +++ b/api_mutate_helper.go @@ -90,15 +90,10 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac } if jsonToDbTags[jsonName] != nil { constraint := jsonToDbTags[jsonName].constraint - if constraint == "unique" || constraint == "term" { - uniqueConstraintFound = true - u.Directive = pb.SchemaUpdate_INDEX - if constraint == "unique" { - u.Tokenizer = []string{"exact"} - } else { - u.Tokenizer = []string{"term"} - } + if constraint == "vector" && valType != pb.Posting_VFLOAT { + return fmt.Errorf("vector index can only be applied to []float values") } + uniqueConstraintFound = addIndex(u, constraint) } sch.Preds = append(sch.Preds, u) @@ -162,13 +157,10 @@ func generateCreateDqlMutationsAndSchemaFromRaw(n *Namespace, data map[string]an ValueType: valType, } if indexes[pred] != "" { - u.Directive = pb.SchemaUpdate_INDEX - index := indexes[pred] - if index == "unique" { - u.Tokenizer = []string{"exact"} - } else if index == "term" { - u.Tokenizer = []string{"term"} + if indexes[pred] == "vector" && valType != pb.Posting_VFLOAT { + return fmt.Errorf("vector index can only be applied to []float values") } + addIndex(u, indexes[pred]) } sch.Preds = append(sch.Preds, u) @@ -289,3 +281,20 @@ func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) return gid, nil } + +func addIndex(u *pb.SchemaUpdate, index string) bool { + u.Directive = pb.SchemaUpdate_INDEX + switch index { + case "unique": + u.Tokenizer = []string{"exact"} + return true + case "term": + u.Tokenizer = []string{"term"} + return true + case "vector": + u.Tokenizer = []string{"hnsw"} + default: + return false + } + return false +} diff --git a/api_test.go b/api_test.go index 9231695..bffefa3 100644 --- a/api_test.go +++ b/api_test.go @@ -553,7 +553,7 @@ func TestRawAPIs(t *testing.T) { require.Equal(t, gid, getGid) require.Equal(t, "B", maps["name"]) - //TODO figure out why it comes back as a flaot64 + //TODO figure out why it comes back as a float64 // schema and value are correctly set to int, so its a query side issue require.Equal(t, float64(20), maps["age"]) require.Equal(t, "123", maps["clerk_id"]) @@ -565,3 +565,87 @@ func TestRawAPIs(t *testing.T) { require.NoError(t, err) require.Equal(t, gid, deleteGid) } + +func TestVectorIndexInsert(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + _, err = modusdb.RawCreate(db, map[string]any{ + "name": "B", + "age": 20, + "clerk_id": "123", + "vec": []float64{1.0, 2.0, 3.0}, + }, map[string]string{ + "clerk_id": "unique", + "vec": "vector", + }, db1.ID()) + + require.NoError(t, err) + + query := `{ + me(func: has(name)) { + uid + name + age + clerk_id + vec + } + }` + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, `{"me":[{"uid":"0x2","name":"B","age":20,"clerk_id":"123","vec":[1,2,3]}]}`, + string(resp.GetJson())) + +} + +func TestVectorIndexSearch(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + vectors := [][]float64{ + {1.0, 2.0, 3.0}, // Sequential + {4.0, 5.0, 6.0}, // Sequential continued + {7.0, 8.0, 9.0}, // Sequential continued + {0.1, 0.2, 0.3}, // Small decimals + {1.5, 2.5, 3.5}, // Half steps + {1.0, 2.0, 3.0}, // Duplicate + {10.0, 20.0, 30.0}, // Tens + {0.5, 1.0, 1.5}, // Half increments + {2.2, 4.4, 6.6}, // Multiples of 2.2 + {3.3, 6.6, 9.9}, // Multiples of 3.3 + } + + for _, vec := range vectors { + _, err = modusdb.RawCreate(db, map[string]any{ + "vec": vec, + }, map[string]string{ + "vec": "vector", + }, db1.ID()) + require.NoError(t, err) + } + + const query = ` + { + vector(func: similar_to(vtest, 1, "[4.1,5.1,6.1]")) { + vtest + } + }` + + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, `{"vector":[{"vtest":[4,5,6]}]}`, string(resp.GetJson())) +} diff --git a/api_types.go b/api_types.go index 860edda..a901423 100644 --- a/api_types.go +++ b/api_types.go @@ -8,6 +8,7 @@ import ( "github.com/dgraph-io/dgo/v240/protos/api" "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/types" "github.com/dgraph-io/dgraph/v24/x" "github.com/twpayne/go-geom" "github.com/twpayne/go-geom/encoding/wkb" @@ -110,6 +111,16 @@ func valueToApiVal(v any) (*api.Value, error) { return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil case float64: return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil + case []float32: + return &api.Value{Val: &api.Value_Vfloat32Val{ + Vfloat32Val: types.FloatArrayAsBytes(val)}}, nil + case []float64: + float32Slice := make([]float32, len(val)) + for i, v := range val { + float32Slice[i] = float32(v) + } + return &api.Value{Val: &api.Value_Vfloat32Val{ + Vfloat32Val: types.FloatArrayAsBytes(float32Slice)}}, nil case []byte: return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil case time.Time: From 851a45dc253b616dcae9f05355e3cdf55c390e3e Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Mon, 30 Dec 2024 23:57:01 -0800 Subject: [PATCH 2/2] fixes --- api_dql.go | 19 +++++++++++-- api_mutate_helper.go | 29 +++++++++++++------ api_query_helper.go | 2 +- api_test.go | 68 +++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 103 insertions(+), 15 deletions(-) diff --git a/api_dql.go b/api_dql.go index 9213784..654fed4 100644 --- a/api_dql.go +++ b/api_dql.go @@ -1,6 +1,8 @@ package modusdb -import "fmt" +import ( + "fmt" +) type QueryFunc func() string @@ -31,6 +33,7 @@ const ( funcUid = `func: uid(%d)` funcEq = `func: eq(%s, %s)` + // funcSimilarTo = `func: similar_to(%s, %d, "%s")` ) func buildUidQuery(gid uint64) QueryFunc { @@ -39,12 +42,24 @@ func buildUidQuery(gid uint64) QueryFunc { } } -func buildEqQuery(key, value any) QueryFunc { +func buildEqQuery(key string, value any) QueryFunc { return func() string { return fmt.Sprintf(funcEq, key, value) } } +// func buildVecSearchQuery(indexAttr string, topK int64, vec []float32) QueryFunc { +// vecStrArr := make([]string, len(vec)) +// for i := range vec { +// vecStrArr[i] = strconv.FormatFloat(float64(vec[i]), 'f', -1, 32) +// } +// vecStr := strings.Join(vecStrArr, ",") +// vecStr = "[" + vecStr + "]" +// return func() string { +// return fmt.Sprintf(funcSimilarTo, indexAttr, topK, vecStr) +// } +// } + func formatObjQuery(qf QueryFunc, extraFields string) string { return fmt.Sprintf(objQuery, qf(), extraFields) } diff --git a/api_mutate_helper.go b/api_mutate_helper.go index 9f7fb62..64fd73d 100644 --- a/api_mutate_helper.go +++ b/api_mutate_helper.go @@ -93,7 +93,10 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac if constraint == "vector" && valType != pb.Posting_VFLOAT { return fmt.Errorf("vector index can only be applied to []float values") } - uniqueConstraintFound = addIndex(u, constraint) + if uniqueConstraintFound { + + } + uniqueConstraintFound = addIndex(u, constraint, uniqueConstraintFound) } sch.Preds = append(sch.Preds, u) @@ -160,7 +163,7 @@ func generateCreateDqlMutationsAndSchemaFromRaw(n *Namespace, data map[string]an if indexes[pred] == "vector" && valType != pb.Posting_VFLOAT { return fmt.Errorf("vector index can only be applied to []float values") } - addIndex(u, indexes[pred]) + addIndex(u, indexes[pred], false) } sch.Preds = append(sch.Preds, u) @@ -282,19 +285,29 @@ func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) return gid, nil } -func addIndex(u *pb.SchemaUpdate, index string) bool { +func addIndex(u *pb.SchemaUpdate, index string, uniqueConstraintExists bool) bool { u.Directive = pb.SchemaUpdate_INDEX switch index { case "unique": u.Tokenizer = []string{"exact"} - return true + uniqueConstraintExists = true case "term": u.Tokenizer = []string{"term"} - return true + uniqueConstraintExists = true case "vector": - u.Tokenizer = []string{"hnsw"} + u.IndexSpecs = []*pb.VectorIndexSpec{ + { + Name: "hnsw", + Options: []*pb.OptionPair{ + { + Key: "metric", + Value: "cosine", + }, + }, + }, + } default: - return false + return uniqueConstraintExists } - return false + return uniqueConstraintExists } diff --git a/api_query_helper.go b/api_query_helper.go index e78200f..d39e274 100644 --- a/api_query_helper.go +++ b/api_query_helper.go @@ -69,7 +69,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac return 0, nil, fmt.Errorf("invalid unique field type") } - if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { + if jsonToDbTag[cf.Key] == nil || jsonToDbTag[cf.Key].constraint == "" { return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) } diff --git a/api_test.go b/api_test.go index bffefa3..3e659a5 100644 --- a/api_test.go +++ b/api_test.go @@ -605,7 +605,7 @@ func TestVectorIndexInsert(t *testing.T) { } -func TestVectorIndexSearch(t *testing.T) { +func TestVectorIndexSearchUntyped(t *testing.T) { ctx := context.Background() db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) require.NoError(t, err) @@ -640,12 +640,72 @@ func TestVectorIndexSearch(t *testing.T) { const query = ` { - vector(func: similar_to(vtest, 1, "[4.1,5.1,6.1]")) { - vtest + vector(func: similar_to(vec, 5, "[4.1,5.1,6.1]")) { + vec } }` resp, err := db1.Query(ctx, query) require.NoError(t, err) - require.JSONEq(t, `{"vector":[{"vtest":[4,5,6]}]}`, string(resp.GetJson())) + require.JSONEq(t, `{ + "vector":[ + {"vec":[4,5,6]}, + {"vec":[7,8,9]}, + {"vec":[1.5,2.5,3.5]}, + {"vec":[10,20,30]}, + {"vec":[2.2,4.4,6.6]} + ] + }`, string(resp.GetJson())) +} + +type Document struct { + Gid uint64 `json:"gid,omitempty"` + Text string `json:"text,omitempty"` + TextVec []float32 `json:"textVec,omitempty" db:"constraint=vector"` +} + +func TestVectorIndexSearchTyped(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + documents := []Document{ + {Text: "apple", TextVec: []float32{1.0, 0.0, 0.0}}, + {Text: "banana", TextVec: []float32{0.0, 1.0, 0.0}}, + {Text: "carrot", TextVec: []float32{0.0, 0.0, 1.0}}, + {Text: "dog", TextVec: []float32{1.0, 1.0, 0.0}}, + {Text: "elephant", TextVec: []float32{0.0, 1.0, 1.0}}, + {Text: "fox", TextVec: []float32{1.0, 0.0, 1.0}}, + {Text: "gorilla", TextVec: []float32{1.0, 1.0, 1.0}}, + } + + for _, doc := range documents { + _, _, err = modusdb.Create(db, &doc, db1.ID()) + require.NoError(t, err) + } + + const query = ` + { + documents(func: similar_to(Document.textVec, 5, "[0.1,0.1,0.1]")) { + Document.text + } + }` + + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, `{ + "documents":[ + {"Document.text":"apple"}, + {"Document.text":"dog"}, + {"Document.text":"elephant"}, + {"Document.text":"fox"}, + {"Document.text":"gorilla"} + ] + }`, string(resp.GetJson())) }