diff --git a/api_dql.go b/api_dql.go index 9213784..654fed4 100644 --- a/api_dql.go +++ b/api_dql.go @@ -1,6 +1,8 @@ package modusdb -import "fmt" +import ( + "fmt" +) type QueryFunc func() string @@ -31,6 +33,7 @@ const ( funcUid = `func: uid(%d)` funcEq = `func: eq(%s, %s)` + // funcSimilarTo = `func: similar_to(%s, %d, "%s")` ) func buildUidQuery(gid uint64) QueryFunc { @@ -39,12 +42,24 @@ func buildUidQuery(gid uint64) QueryFunc { } } -func buildEqQuery(key, value any) QueryFunc { +func buildEqQuery(key string, value any) QueryFunc { return func() string { return fmt.Sprintf(funcEq, key, value) } } +// func buildVecSearchQuery(indexAttr string, topK int64, vec []float32) QueryFunc { +// vecStrArr := make([]string, len(vec)) +// for i := range vec { +// vecStrArr[i] = strconv.FormatFloat(float64(vec[i]), 'f', -1, 32) +// } +// vecStr := strings.Join(vecStrArr, ",") +// vecStr = "[" + vecStr + "]" +// return func() string { +// return fmt.Sprintf(funcSimilarTo, indexAttr, topK, vecStr) +// } +// } + func formatObjQuery(qf QueryFunc, extraFields string) string { return fmt.Sprintf(objQuery, qf(), extraFields) } diff --git a/api_mutate_helper.go b/api_mutate_helper.go index e997258..64fd73d 100644 --- a/api_mutate_helper.go +++ b/api_mutate_helper.go @@ -90,15 +90,13 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac } if jsonToDbTags[jsonName] != nil { constraint := jsonToDbTags[jsonName].constraint - if constraint == "unique" || constraint == "term" { - uniqueConstraintFound = true - u.Directive = pb.SchemaUpdate_INDEX - if constraint == "unique" { - u.Tokenizer = []string{"exact"} - } else { - u.Tokenizer = []string{"term"} - } + if constraint == "vector" && valType != pb.Posting_VFLOAT { + return fmt.Errorf("vector index can only be applied to []float values") + } + if uniqueConstraintFound { + } + uniqueConstraintFound = addIndex(u, constraint, uniqueConstraintFound) } sch.Preds = append(sch.Preds, u) @@ -162,13 +160,10 @@ func generateCreateDqlMutationsAndSchemaFromRaw(n *Namespace, data map[string]an ValueType: valType, } if indexes[pred] != "" { - u.Directive = pb.SchemaUpdate_INDEX - index := indexes[pred] - if index == "unique" { - u.Tokenizer = []string{"exact"} - } else if index == "term" { - u.Tokenizer = []string{"term"} + if indexes[pred] == "vector" && valType != pb.Posting_VFLOAT { + return fmt.Errorf("vector index can only be applied to []float values") } + addIndex(u, indexes[pred], false) } sch.Preds = append(sch.Preds, u) @@ -289,3 +284,30 @@ func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) return gid, nil } + +func addIndex(u *pb.SchemaUpdate, index string, uniqueConstraintExists bool) bool { + u.Directive = pb.SchemaUpdate_INDEX + switch index { + case "unique": + u.Tokenizer = []string{"exact"} + uniqueConstraintExists = true + case "term": + u.Tokenizer = []string{"term"} + uniqueConstraintExists = true + case "vector": + u.IndexSpecs = []*pb.VectorIndexSpec{ + { + Name: "hnsw", + Options: []*pb.OptionPair{ + { + Key: "metric", + Value: "cosine", + }, + }, + }, + } + default: + return uniqueConstraintExists + } + return uniqueConstraintExists +} diff --git a/api_query_helper.go b/api_query_helper.go index e78200f..d39e274 100644 --- a/api_query_helper.go +++ b/api_query_helper.go @@ -69,7 +69,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac return 0, nil, fmt.Errorf("invalid unique field type") } - if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { + if jsonToDbTag[cf.Key] == nil || jsonToDbTag[cf.Key].constraint == "" { return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) } diff --git a/api_test.go b/api_test.go index 9231695..3e659a5 100644 --- a/api_test.go +++ b/api_test.go @@ -553,7 +553,7 @@ func TestRawAPIs(t *testing.T) { require.Equal(t, gid, getGid) require.Equal(t, "B", maps["name"]) - //TODO figure out why it comes back as a flaot64 + //TODO figure out why it comes back as a float64 // schema and value are correctly set to int, so its a query side issue require.Equal(t, float64(20), maps["age"]) require.Equal(t, "123", maps["clerk_id"]) @@ -565,3 +565,147 @@ func TestRawAPIs(t *testing.T) { require.NoError(t, err) require.Equal(t, gid, deleteGid) } + +func TestVectorIndexInsert(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + _, err = modusdb.RawCreate(db, map[string]any{ + "name": "B", + "age": 20, + "clerk_id": "123", + "vec": []float64{1.0, 2.0, 3.0}, + }, map[string]string{ + "clerk_id": "unique", + "vec": "vector", + }, db1.ID()) + + require.NoError(t, err) + + query := `{ + me(func: has(name)) { + uid + name + age + clerk_id + vec + } + }` + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, `{"me":[{"uid":"0x2","name":"B","age":20,"clerk_id":"123","vec":[1,2,3]}]}`, + string(resp.GetJson())) + +} + +func TestVectorIndexSearchUntyped(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + vectors := [][]float64{ + {1.0, 2.0, 3.0}, // Sequential + {4.0, 5.0, 6.0}, // Sequential continued + {7.0, 8.0, 9.0}, // Sequential continued + {0.1, 0.2, 0.3}, // Small decimals + {1.5, 2.5, 3.5}, // Half steps + {1.0, 2.0, 3.0}, // Duplicate + {10.0, 20.0, 30.0}, // Tens + {0.5, 1.0, 1.5}, // Half increments + {2.2, 4.4, 6.6}, // Multiples of 2.2 + {3.3, 6.6, 9.9}, // Multiples of 3.3 + } + + for _, vec := range vectors { + _, err = modusdb.RawCreate(db, map[string]any{ + "vec": vec, + }, map[string]string{ + "vec": "vector", + }, db1.ID()) + require.NoError(t, err) + } + + const query = ` + { + vector(func: similar_to(vec, 5, "[4.1,5.1,6.1]")) { + vec + } + }` + + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, `{ + "vector":[ + {"vec":[4,5,6]}, + {"vec":[7,8,9]}, + {"vec":[1.5,2.5,3.5]}, + {"vec":[10,20,30]}, + {"vec":[2.2,4.4,6.6]} + ] + }`, string(resp.GetJson())) +} + +type Document struct { + Gid uint64 `json:"gid,omitempty"` + Text string `json:"text,omitempty"` + TextVec []float32 `json:"textVec,omitempty" db:"constraint=vector"` +} + +func TestVectorIndexSearchTyped(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + documents := []Document{ + {Text: "apple", TextVec: []float32{1.0, 0.0, 0.0}}, + {Text: "banana", TextVec: []float32{0.0, 1.0, 0.0}}, + {Text: "carrot", TextVec: []float32{0.0, 0.0, 1.0}}, + {Text: "dog", TextVec: []float32{1.0, 1.0, 0.0}}, + {Text: "elephant", TextVec: []float32{0.0, 1.0, 1.0}}, + {Text: "fox", TextVec: []float32{1.0, 0.0, 1.0}}, + {Text: "gorilla", TextVec: []float32{1.0, 1.0, 1.0}}, + } + + for _, doc := range documents { + _, _, err = modusdb.Create(db, &doc, db1.ID()) + require.NoError(t, err) + } + + const query = ` + { + documents(func: similar_to(Document.textVec, 5, "[0.1,0.1,0.1]")) { + Document.text + } + }` + + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, `{ + "documents":[ + {"Document.text":"apple"}, + {"Document.text":"dog"}, + {"Document.text":"elephant"}, + {"Document.text":"fox"}, + {"Document.text":"gorilla"} + ] + }`, string(resp.GetJson())) +} diff --git a/api_types.go b/api_types.go index 860edda..a901423 100644 --- a/api_types.go +++ b/api_types.go @@ -8,6 +8,7 @@ import ( "github.com/dgraph-io/dgo/v240/protos/api" "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/types" "github.com/dgraph-io/dgraph/v24/x" "github.com/twpayne/go-geom" "github.com/twpayne/go-geom/encoding/wkb" @@ -110,6 +111,16 @@ func valueToApiVal(v any) (*api.Value, error) { return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil case float64: return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil + case []float32: + return &api.Value{Val: &api.Value_Vfloat32Val{ + Vfloat32Val: types.FloatArrayAsBytes(val)}}, nil + case []float64: + float32Slice := make([]float32, len(val)) + for i, v := range val { + float32Slice[i] = float32(v) + } + return &api.Value{Val: &api.Value_Vfloat32Val{ + Vfloat32Val: types.FloatArrayAsBytes(float32Slice)}}, nil case []byte: return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil case time.Time: