diff --git a/api.go b/api.go index 7fcdccd..e329198 100644 --- a/api.go +++ b/api.go @@ -7,7 +7,7 @@ import ( "github.com/dgraph-io/dgraph/v24/schema" ) -func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { +func Create[T any](db *DB, object T, ns ...uint64) (uint64, T, error) { db.mutex.Lock() defer db.mutex.Unlock() if len(ns) > 1 { @@ -25,7 +25,7 @@ func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema[T](ctx, n, *object, gid, &dms, sch) + err = generateCreateDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) if err != nil { return 0, object, err } @@ -43,7 +43,7 @@ func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { return getByGid[T](ctx, n, gid) } -func Upsert[T any](db *DB, object *T, ns ...uint64) (uint64, *T, bool, error) { +func Upsert[T any](db *DB, object T, ns ...uint64) (uint64, T, bool, error) { var wasFound bool db.mutex.Lock() @@ -51,80 +51,78 @@ func Upsert[T any](db *DB, object *T, ns ...uint64) (uint64, *T, bool, error) { if len(ns) > 1 { return 0, object, false, fmt.Errorf("only one namespace is allowed") } - if object == nil { - return 0, nil, false, fmt.Errorf("object is nil") - } ctx, n, err := getDefaultNamespace(db, ns...) if err != nil { return 0, object, false, err } - gid, cf, err := getUniqueConstraint[T](*object) + gid, cf, err := getUniqueConstraint[T](object) if err != nil { - return 0, nil, false, err + return 0, object, false, err } dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema[T](ctx, n, *object, gid, &dms, sch) + err = generateCreateDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) if err != nil { - return 0, nil, false, err + return 0, object, false, err } err = n.alterSchemaWithParsed(ctx, sch) if err != nil { - return 0, nil, false, err + return 0, object, false, err } if gid != 0 { - gid, _, err = getByGidWithObject[T](ctx, n, gid, *object) + gid, _, err = getByGidWithObject[T](ctx, n, gid, object) if err != nil && err != ErrNoObjFound { - return 0, nil, false, err + return 0, object, false, err } wasFound = err == nil } else if cf != nil { - gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, *object) + gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) if err != nil && err != ErrNoObjFound { - return 0, nil, false, err + return 0, object, false, err } wasFound = err == nil } if gid == 0 { gid, err = db.z.nextUID() if err != nil { - return 0, nil, false, err + return 0, object, false, err } } dms = make([]*dql.Mutation, 0) - err = generateCreateDqlMutationsAndSchema[T](ctx, n, *object, gid, &dms, sch) + err = generateCreateDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) if err != nil { - return 0, nil, false, err + return 0, object, false, err } err = applyDqlMutations(ctx, db, dms) if err != nil { - return 0, nil, false, err + return 0, object, false, err } gid, object, err = getByGid[T](ctx, n, gid) if err != nil { - return 0, nil, false, err + return 0, object, false, err } return gid, object, wasFound, nil } -func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, error) { +func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, T, error) { db.mutex.Lock() defer db.mutex.Unlock() + var obj T if len(ns) > 1 { - return 0, nil, fmt.Errorf("only one namespace is allowed") + return 0, obj, fmt.Errorf("only one namespace is allowed") } ctx, n, err := getDefaultNamespace(db, ns...) if err != nil { - return 0, nil, err + return 0, obj, err } if uid, ok := any(uniqueField).(uint64); ok { return getByGid[T](ctx, n, uid) @@ -134,30 +132,45 @@ func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, return getByConstrainedField[T](ctx, n, cf) } - return 0, nil, fmt.Errorf("invalid unique field type") + return 0, obj, fmt.Errorf("invalid unique field type") +} + +func Query[T any](db *DB, queryParams QueryParams, ns ...uint64) ([]uint64, []T, error) { + db.mutex.Lock() + defer db.mutex.Unlock() + if len(ns) > 1 { + return nil, nil, fmt.Errorf("only one namespace is allowed") + } + ctx, n, err := getDefaultNamespace(db, ns...) + if err != nil { + return nil, nil, err + } + + return executeQuery[T](ctx, n, queryParams, false) } -func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, error) { +func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, T, error) { db.mutex.Lock() defer db.mutex.Unlock() + var zeroObj T if len(ns) > 1 { - return 0, nil, fmt.Errorf("only one namespace is allowed") + return 0, zeroObj, fmt.Errorf("only one namespace is allowed") } ctx, n, err := getDefaultNamespace(db, ns...) if err != nil { - return 0, nil, err + return 0, zeroObj, err } if uid, ok := any(uniqueField).(uint64); ok { uid, obj, err := getByGid[T](ctx, n, uid) if err != nil { - return 0, nil, err + return 0, zeroObj, err } dms := generateDeleteDqlMutations(n, uid) err = applyDqlMutations(ctx, db, dms) if err != nil { - return 0, nil, err + return 0, zeroObj, err } return uid, obj, nil @@ -166,18 +179,18 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, if cf, ok := any(uniqueField).(ConstrainedField); ok { uid, obj, err := getByConstrainedField[T](ctx, n, cf) if err != nil { - return 0, nil, err + return 0, zeroObj, err } dms := generateDeleteDqlMutations(n, uid) err = applyDqlMutations(ctx, db, dms) if err != nil { - return 0, nil, err + return 0, zeroObj, err } return uid, obj, nil } - return 0, nil, fmt.Errorf("invalid unique field type") + return 0, zeroObj, fmt.Errorf("invalid unique field type") } diff --git a/api_dql.go b/api_dql.go index a2b0e65..b20ecf0 100644 --- a/api_dql.go +++ b/api_dql.go @@ -1,13 +1,17 @@ package modusdb -import "fmt" +import ( + "fmt" + "strconv" + "strings" +) type QueryFunc func() string const ( objQuery = ` { - obj(%s) { + obj(func: %s) { uid expand(_all_) { uid @@ -20,8 +24,33 @@ const ( } ` - funcUid = `func: uid(%d)` - funcEq = `func: eq(%s, %s)` + objsQuery = ` + { + objs(func: type("%s")%s) @filter(%s) { + uid + expand(_all_) { + uid + expand(_all_) + dgraph.type + } + dgraph.type + %s + } + } + ` + + funcUid = `uid(%d)` + funcEq = `eq(%s, %s)` + funcSimilarTo = `similar_to(%s, %d, "[%s]")` + funcAllOfTerms = `allofterms(%s, "%s")` + funcAnyOfTerms = `anyofterms(%s, "%s")` + funcAllOfText = `alloftext(%s, "%s")` + funcAnyOfText = `anyoftext(%s, "%s")` + funcRegExp = `regexp(%s, /%s/)` + funcLe = `le(%s, %s)` + funcGe = `ge(%s, %s)` + funcGt = `gt(%s, %s)` + funcLt = `lt(%s, %s)` ) func buildUidQuery(gid uint64) QueryFunc { @@ -30,12 +59,152 @@ func buildUidQuery(gid uint64) QueryFunc { } } -func buildEqQuery(key, value any) QueryFunc { +func buildEqQuery(key string, value any) QueryFunc { return func() string { return fmt.Sprintf(funcEq, key, value) } } +func buildSimilarToQuery(indexAttr string, topK int64, vec []float32) QueryFunc { + vecStrArr := make([]string, len(vec)) + for i := range vec { + vecStrArr[i] = strconv.FormatFloat(float64(vec[i]), 'f', -1, 32) + } + vecStr := strings.Join(vecStrArr, ",") + return func() string { + return fmt.Sprintf(funcSimilarTo, indexAttr, topK, vecStr) + } +} + +func buildAllOfTermsQuery(attr string, terms string) QueryFunc { + return func() string { + return fmt.Sprintf(funcAllOfTerms, attr, terms) + } +} + +func buildAnyOfTermsQuery(attr string, terms string) QueryFunc { + return func() string { + return fmt.Sprintf(funcAnyOfTerms, attr, terms) + } +} + +func buildAllOfTextQuery(attr, text string) QueryFunc { + return func() string { + return fmt.Sprintf(funcAllOfText, attr, text) + } +} + +func buildAnyOfTextQuery(attr, text string) QueryFunc { + return func() string { + return fmt.Sprintf(funcAnyOfText, attr, text) + } +} + +func buildRegExpQuery(attr, pattern string) QueryFunc { + return func() string { + return fmt.Sprintf(funcRegExp, attr, pattern) + } +} + +func buildLeQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(funcLe, attr, value) + } +} + +func buildGeQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(funcGe, attr, value) + } +} + +func buildGtQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(funcGt, attr, value) + } +} + +func buildLtQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(funcLt, attr, value) + } +} + +func And(qfs ...QueryFunc) QueryFunc { + return func() string { + qs := make([]string, len(qfs)) + for i, qf := range qfs { + qs[i] = qf() + } + return strings.Join(qs, " AND ") + } +} + +func Or(qfs ...QueryFunc) QueryFunc { + return func() string { + qs := make([]string, len(qfs)) + for i, qf := range qfs { + qs[i] = qf() + } + return strings.Join(qs, " OR ") + } +} + +func Not(qf QueryFunc) QueryFunc { + return func() string { + return "NOT " + qf() + } +} + func formatObjQuery(qf QueryFunc, extraFields string) string { return fmt.Sprintf(objQuery, qf(), extraFields) } + +func formatObjsQuery(typeName string, qf QueryFunc, paginationAndSorting string, extraFields string) string { + return fmt.Sprintf(objsQuery, typeName, paginationAndSorting, qf(), extraFields) +} + +// Helper function to combine multiple filters +func filtersToQueryFunc(typeName string, filter Filter) QueryFunc { + return filterToQueryFunc(typeName, filter) +} + +func paginationToQueryString(p Pagination) string { + paginationStr := "" + if p.Limit > 0 { + paginationStr += ", " + fmt.Sprintf("first: %d", p.Limit) + } + if p.Offset > 0 { + paginationStr += ", " + fmt.Sprintf("offset: %d", p.Offset) + } else if p.After != "" { + paginationStr += ", " + fmt.Sprintf("after: %s", p.After) + } + if paginationStr == "" { + return "" + } + return paginationStr +} + +func sortingToQueryString(typeName string, s Sorting) string { + if s.OrderAscField == "" && s.OrderDescField == "" { + return "" + } + + var parts []string + first, second := s.OrderDescField, s.OrderAscField + firstOp, secondOp := "orderdesc", "orderasc" + + if !s.OrderDescFirst { + first, second = s.OrderAscField, s.OrderDescField + firstOp, secondOp = "orderasc", "orderdesc" + } + + if first != "" { + parts = append(parts, fmt.Sprintf("%s: %s", firstOp, getPredicateName(typeName, first))) + } + if second != "" { + parts = append(parts, fmt.Sprintf("%s: %s", secondOp, getPredicateName(typeName, second))) + } + + return ", " + strings.Join(parts, ", ") +} diff --git a/api_mutate_helper.go b/api_mutate_helper.go index 7957e0a..3d381c1 100644 --- a/api_mutate_helper.go +++ b/api_mutate_helper.go @@ -90,15 +90,10 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac } if jsonToDbTags[jsonName] != nil { constraint := jsonToDbTags[jsonName].constraint - if constraint == "unique" || constraint == "term" { - uniqueConstraintFound = true - u.Directive = pb.SchemaUpdate_INDEX - if constraint == "unique" { - u.Tokenizer = []string{"exact"} - } else { - u.Tokenizer = []string{"term"} - } + if constraint == "vector" && valType != pb.Posting_VFLOAT { + return fmt.Errorf("vector index can only be applied to []float values") } + uniqueConstraintFound = addIndex(u, constraint, uniqueConstraintFound) } sch.Preds = append(sch.Preds, u) @@ -238,3 +233,39 @@ func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) return gid, nil } + +func addIndex(u *pb.SchemaUpdate, index string, uniqueConstraintExists bool) bool { + u.Directive = pb.SchemaUpdate_INDEX + switch index { + case "exact": + u.Tokenizer = []string{"exact"} + case "term": + u.Tokenizer = []string{"term"} + case "hash": + u.Tokenizer = []string{"hash"} + case "unique": + u.Tokenizer = []string{"exact"} + u.Unique = true + u.Upsert = true + uniqueConstraintExists = true + case "fulltext": + u.Tokenizer = []string{"fulltext"} + case "trigram": + u.Tokenizer = []string{"trigram"} + case "vector": + u.IndexSpecs = []*pb.VectorIndexSpec{ + { + Name: "hnsw", + Options: []*pb.OptionPair{ + { + Key: "metric", + Value: "cosine", + }, + }, + }, + } + default: + return uniqueConstraintExists + } + return uniqueConstraintExists +} diff --git a/api_query_helper.go b/api_query_helper.go index e62d8f3..a277d79 100644 --- a/api_query_helper.go +++ b/api_query_helper.go @@ -7,41 +7,40 @@ import ( "reflect" ) -func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { +func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, T, error) { return executeGet[T](ctx, n, gid) } -func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, obj T) (uint64, *T, error) { +func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, obj T) (uint64, T, error) { return executeGetWithObject[T](ctx, n, obj, false, gid) } -func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { +func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, T, error) { return executeGet[T](ctx, n, cf) } func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, - cf ConstrainedField, obj T) (uint64, *T, error) { + cf ConstrainedField, obj T) (uint64, T, error) { return executeGetWithObject[T](ctx, n, obj, false, cf) } -func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, args ...R) (uint64, *T, error) { +func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, args ...R) (uint64, T, error) { + var obj T if len(args) != 1 { - return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args)) + return 0, obj, fmt.Errorf("expected 1 argument, got %d", len(args)) } - var obj T - return executeGetWithObject(ctx, n, obj, true, args...) } func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, - obj T, withReverse bool, args ...R) (uint64, *T, error) { + obj T, withReverse bool, args ...R) (uint64, T, error) { t := reflect.TypeOf(obj) fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := getFieldTags(t) if err != nil { - return 0, nil, err + return 0, obj, err } readFromQuery := "" if withReverse { @@ -64,16 +63,16 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac } else if cf, ok = any(args[0]).(ConstrainedField); ok { query = formatObjQuery(buildEqQuery(getPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) } else { - return 0, nil, fmt.Errorf("invalid unique field type") + return 0, obj, fmt.Errorf("invalid unique field type") } if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { - return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) + return 0, obj, fmt.Errorf("constraint not defined for field %s", cf.Key) } resp, err := n.queryWithLock(ctx, query) if err != nil { - return 0, nil, err + return 0, obj, err } dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) @@ -88,36 +87,122 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac // Unmarshal the JSON response into the dynamic struct if err := json.Unmarshal(resp.Json, &result); err != nil { - return 0, nil, err + return 0, obj, err } // Check if we have at least one object in the response if len(result.Obj) == 0 { - return 0, nil, ErrNoObjFound + return 0, obj, ErrNoObjFound } // Map the dynamic struct to the final type T finalObject := reflect.New(t).Interface() gid, err = mapDynamicToFinal(result.Obj[0], finalObject) if err != nil { - return 0, nil, err + return 0, obj, err + } + + if typedPtr, ok := finalObject.(*T); ok { + return gid, *typedPtr, nil + } + + if dirType, ok := finalObject.(T); ok { + return gid, dirType, nil } - // Convert to *interface{} then to *T - if ifacePtr, ok := finalObject.(*interface{}); ok { - if typedPtr, ok := (*ifacePtr).(*T); ok { - return gid, typedPtr, nil + return 0, obj, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) +} + +func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryParams, + withReverse bool) ([]uint64, []T, error) { + var obj T + t := reflect.TypeOf(obj) + fieldToJsonTags, _, jsonToReverseEdgeTags, err := getFieldTags(t) + if err != nil { + return nil, nil, err + } + + var filterQueryFunc QueryFunc = func() string { + return "" + } + var paginationAndSorting string + if queryParams.Filter != nil { + filterQueryFunc = filtersToQueryFunc(t.Name(), *queryParams.Filter) + } + if queryParams.Pagination != nil || queryParams.Sorting != nil { + var pagination, sorting string + if queryParams.Pagination != nil { + pagination = paginationToQueryString(*queryParams.Pagination) } + if queryParams.Sorting != nil { + sorting = sortingToQueryString(t.Name(), *queryParams.Sorting) + } + paginationAndSorting = fmt.Sprintf("%s %s", pagination, sorting) } - // If conversion fails, try direct conversion - if typedPtr, ok := finalObject.(*T); ok { - return gid, typedPtr, nil + readFromQuery := "" + if withReverse { + for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { + readFromQuery += fmt.Sprintf(` + %s: ~%s { + uid + expand(_all_) + dgraph.type + } + `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + } } - if dirType, ok := finalObject.(T); ok { - return gid, &dirType, nil + query := formatObjsQuery(t.Name(), filterQueryFunc, paginationAndSorting, readFromQuery) + + resp, err := n.queryWithLock(ctx, query) + if err != nil { + return nil, nil, err + } + + dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) + + var result struct { + Objs []any `json:"objs"` + } + + var tempMap map[string][]any + if err := json.Unmarshal(resp.Json, &tempMap); err != nil { + return nil, nil, err + } + + // Determine the number of elements + numElements := len(tempMap["objs"]) + + // Append the interface the correct number of times + for i := 0; i < numElements; i++ { + result.Objs = append(result.Objs, reflect.New(dynamicType).Interface()) + } + + // Unmarshal the JSON response into the dynamic struct + if err := json.Unmarshal(resp.Json, &result); err != nil { + return nil, nil, err + } + + var gids []uint64 + var objs []T + for _, obj := range result.Objs { + finalObject := reflect.New(t).Interface() + gid, err := mapDynamicToFinal(obj, finalObject) + if err != nil { + return nil, nil, err + } + + if typedPtr, ok := finalObject.(*T); ok { + gids = append(gids, gid) + objs = append(objs, *typedPtr) + } else if dirType, ok := finalObject.(T); ok { + gids = append(gids, gid) + objs = append(objs, dirType) + } else { + return nil, nil, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) + } } - return 0, nil, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) + return gids, objs, nil } diff --git a/api_reflect.go b/api_reflect.go index f74cd53..832358e 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -183,7 +183,7 @@ func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { return gid, nil, nil } } - if jsonToDbTags[jsonName] != nil && jsonToDbTags[jsonName].constraint == "unique" { + if jsonToDbTags[jsonName] != nil && isValidUniqueIndex(jsonToDbTags[jsonName].constraint) { // check if value is zero or nil if value == reflect.Zero(reflect.TypeOf(value)).Interface() || value == nil { continue @@ -197,3 +197,7 @@ func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { return 0, nil, fmt.Errorf(NoUniqueConstr, t.Name()) } + +func isValidUniqueIndex(name string) bool { + return name == "unique" +} diff --git a/api_test.go b/api_test.go index 1d3293c..701b30c 100644 --- a/api_test.go +++ b/api_test.go @@ -22,7 +22,7 @@ func TestFirstTimeUser(t *testing.T) { require.NoError(t, err) defer db.Close() - gid, user, err := modusdb.Create(db, &User{ + gid, user, err := modusdb.Create(db, User{ Name: "A", Age: 10, ClerkId: "123", @@ -59,7 +59,7 @@ func TestFirstTimeUser(t *testing.T) { _, queriedUser3, err := modusdb.Get[User](db, gid) require.Error(t, err) require.Equal(t, "no object found", err.Error()) - require.Nil(t, queriedUser3) + require.Equal(t, queriedUser3, User{}) } @@ -74,7 +74,7 @@ func TestCreateApi(t *testing.T) { require.NoError(t, db1.DropData(ctx)) - user := &User{ + user := User{ Name: "B", Age: 20, ClerkId: "123", @@ -131,7 +131,7 @@ func TestCreateApiWithNonStruct(t *testing.T) { require.NoError(t, db1.DropData(ctx)) - user := &User{ + user := User{ Name: "B", Age: 20, } @@ -152,7 +152,7 @@ func TestGetApi(t *testing.T) { require.NoError(t, db1.DropData(ctx)) - user := &User{ + user := User{ Name: "B", Age: 20, ClerkId: "123", @@ -181,7 +181,7 @@ func TestGetApiWithConstrainedField(t *testing.T) { require.NoError(t, db1.DropData(ctx)) - user := &User{ + user := User{ Name: "B", Age: 20, ClerkId: "123", @@ -213,7 +213,7 @@ func TestDeleteApi(t *testing.T) { require.NoError(t, db1.DropData(ctx)) - user := &User{ + user := User{ Name: "B", Age: 20, ClerkId: "123", @@ -228,7 +228,7 @@ func TestDeleteApi(t *testing.T) { _, queriedUser, err := modusdb.Get[User](db, gid, db1.ID()) require.Error(t, err) require.Equal(t, "no object found", err.Error()) - require.Nil(t, queriedUser) + require.Equal(t, queriedUser, User{}) _, queriedUser, err = modusdb.Get[User](db, modusdb.ConstrainedField{ Key: "clerk_id", @@ -236,7 +236,7 @@ func TestDeleteApi(t *testing.T) { }, db1.ID()) require.Error(t, err) require.Equal(t, "no object found", err.Error()) - require.Nil(t, queriedUser) + require.Equal(t, queriedUser, User{}) } func TestUpsertApi(t *testing.T) { @@ -250,7 +250,7 @@ func TestUpsertApi(t *testing.T) { require.NoError(t, db1.DropData(ctx)) - user := &User{ + user := User{ Name: "B", Age: 20, ClerkId: "123", @@ -273,6 +273,123 @@ func TestUpsertApi(t *testing.T) { require.Equal(t, "123", queriedUser.ClerkId) } +func TestQueryApi(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + users := []User{ + {Name: "A", Age: 10, ClerkId: "123"}, + {Name: "B", Age: 20, ClerkId: "123"}, + {Name: "C", Age: 30, ClerkId: "123"}, + {Name: "D", Age: 40, ClerkId: "123"}, + {Name: "E", Age: 50, ClerkId: "123"}, + } + + for _, user := range users { + _, _, err = modusdb.Create(db, user, db1.ID()) + require.NoError(t, err) + } + + gids, queriedUsers, err := modusdb.Query[User](db, modusdb.QueryParams{}, db1.ID()) + require.NoError(t, err) + require.Len(t, queriedUsers, 5) + require.Len(t, gids, 5) + require.Equal(t, "A", queriedUsers[0].Name) + require.Equal(t, "B", queriedUsers[1].Name) + require.Equal(t, "C", queriedUsers[2].Name) + require.Equal(t, "D", queriedUsers[3].Name) + require.Equal(t, "E", queriedUsers[4].Name) + + gids, queriedUsers, err = modusdb.Query[User](db, modusdb.QueryParams{ + Filter: &modusdb.Filter{ + Field: "age", + String: modusdb.StringPredicate{ + // The reason its a string even for int is bc i cant tell if + // user wants to compare with 0 the number or didn't provide a value + // TODO: fix this + GreaterOrEqual: fmt.Sprintf("%d", 20), + }, + }, + }, db1.ID()) + + require.NoError(t, err) + require.Len(t, queriedUsers, 4) + require.Len(t, gids, 4) + require.Equal(t, "B", queriedUsers[0].Name) + require.Equal(t, "C", queriedUsers[1].Name) + require.Equal(t, "D", queriedUsers[2].Name) + require.Equal(t, "E", queriedUsers[3].Name) +} + +func TestQueryApiWithPaginiationAndSorting(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + users := []User{ + {Name: "A", Age: 10, ClerkId: "123"}, + {Name: "B", Age: 20, ClerkId: "123"}, + {Name: "C", Age: 30, ClerkId: "123"}, + {Name: "D", Age: 40, ClerkId: "123"}, + {Name: "E", Age: 50, ClerkId: "123"}, + } + + for _, user := range users { + _, _, err = modusdb.Create(db, user, db1.ID()) + require.NoError(t, err) + } + + gids, queriedUsers, err := modusdb.Query[User](db, modusdb.QueryParams{ + Filter: &modusdb.Filter{ + Field: "age", + String: modusdb.StringPredicate{ + GreaterOrEqual: fmt.Sprintf("%d", 20), + }, + }, + Pagination: &modusdb.Pagination{ + Limit: 3, + Offset: 1, + }, + }, db1.ID()) + + require.NoError(t, err) + require.Len(t, queriedUsers, 3) + require.Len(t, gids, 3) + require.Equal(t, "C", queriedUsers[0].Name) + require.Equal(t, "D", queriedUsers[1].Name) + require.Equal(t, "E", queriedUsers[2].Name) + + gids, queriedUsers, err = modusdb.Query[User](db, modusdb.QueryParams{ + Pagination: &modusdb.Pagination{ + Limit: 3, + Offset: 1, + }, + Sorting: &modusdb.Sorting{ + OrderAscField: "age", + }, + }, db1.ID()) + + require.NoError(t, err) + require.Len(t, queriedUsers, 3) + require.Len(t, gids, 3) + require.Equal(t, "B", queriedUsers[0].Name) + require.Equal(t, "C", queriedUsers[1].Name) + require.Equal(t, "D", queriedUsers[2].Name) +} + type Project struct { Gid uint64 `json:"gid,omitempty"` Name string `json:"name,omitempty"` @@ -298,7 +415,7 @@ func TestNestedObjectMutation(t *testing.T) { require.NoError(t, db1.DropData(ctx)) - branch := &Branch{ + branch := Branch{ Name: "B", ClerkId: "123", Proj: Project{ @@ -352,7 +469,7 @@ func TestLinkingObjectsByConstrainedFields(t *testing.T) { require.NoError(t, db1.DropData(ctx)) - projGid, project, err := modusdb.Create(db, &Project{ + projGid, project, err := modusdb.Create(db, Project{ Name: "P", ClerkId: "456", }, db1.ID()) @@ -361,7 +478,7 @@ func TestLinkingObjectsByConstrainedFields(t *testing.T) { require.Equal(t, "P", project.Name) require.Equal(t, project.Gid, projGid) - branch := &Branch{ + branch := Branch{ Name: "B", ClerkId: "123", Proj: Project{ @@ -415,7 +532,7 @@ func TestLinkingObjectsByGid(t *testing.T) { require.NoError(t, db1.DropData(ctx)) - projGid, project, err := modusdb.Create(db, &Project{ + projGid, project, err := modusdb.Create(db, Project{ Name: "P", ClerkId: "456", }, db1.ID()) @@ -424,7 +541,7 @@ func TestLinkingObjectsByGid(t *testing.T) { require.Equal(t, "P", project.Name) require.Equal(t, project.Gid, projGid) - branch := &Branch{ + branch := Branch{ Name: "B", ClerkId: "123", Proj: Project{ @@ -489,7 +606,7 @@ func TestNestedObjectMutationWithBadType(t *testing.T) { require.NoError(t, db1.DropData(ctx)) - branch := &BadBranch{ + branch := BadBranch{ Name: "B", ClerkId: "123", Proj: BadProject{ @@ -502,7 +619,7 @@ func TestNestedObjectMutationWithBadType(t *testing.T) { require.Error(t, err) require.Equal(t, fmt.Sprintf(modusdb.NoUniqueConstr, "BadProject"), err.Error()) - proj := &BadProject{ + proj := BadProject{ Name: "P", ClerkId: "456", } @@ -512,3 +629,120 @@ func TestNestedObjectMutationWithBadType(t *testing.T) { require.Equal(t, fmt.Sprintf(modusdb.NoUniqueConstr, "BadProject"), err.Error()) } + +type Document struct { + Gid uint64 `json:"gid,omitempty"` + Text string `json:"text,omitempty"` + TextVec []float32 `json:"textVec,omitempty" db:"constraint=vector"` +} + +func TestVectorIndexSearchTyped(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + documents := []Document{ + {Text: "apple", TextVec: []float32{0.1, 0.1, 0.0}}, + {Text: "banana", TextVec: []float32{0.0, 1.0, 0.0}}, + {Text: "carrot", TextVec: []float32{0.0, 0.0, 1.0}}, + {Text: "dog", TextVec: []float32{1.0, 1.0, 0.0}}, + {Text: "elephant", TextVec: []float32{0.0, 1.0, 1.0}}, + {Text: "fox", TextVec: []float32{1.0, 0.0, 1.0}}, + {Text: "gorilla", TextVec: []float32{1.0, 1.0, 1.0}}, + } + + for _, doc := range documents { + _, _, err = modusdb.Create(db, doc, db1.ID()) + require.NoError(t, err) + } + + const query = ` + { + documents(func: similar_to(Document.textVec, 5, "[0.1,0.1,0.1]")) { + Document.text + } + }` + + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, `{ + "documents":[ + {"Document.text":"apple"}, + {"Document.text":"dog"}, + {"Document.text":"elephant"}, + {"Document.text":"fox"}, + {"Document.text":"gorilla"} + ] + }`, string(resp.GetJson())) + + const query2 = ` + { + documents(func: type("Document")) @filter(similar_to(Document.textVec, 5, "[0.1,0.1,0.1]")) { + Document.text + } + }` + + resp, err = db1.Query(ctx, query2) + require.NoError(t, err) + require.JSONEq(t, `{ + "documents":[ + {"Document.text":"apple"}, + {"Document.text":"dog"}, + {"Document.text":"elephant"}, + {"Document.text":"fox"}, + {"Document.text":"gorilla"} + ] + }`, string(resp.GetJson())) +} + +func TestVectorIndexSearchWithQuery(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + documents := []Document{ + {Text: "apple", TextVec: []float32{0.1, 0.1, 0.0}}, + {Text: "banana", TextVec: []float32{0.0, 1.0, 0.0}}, + {Text: "carrot", TextVec: []float32{0.0, 0.0, 1.0}}, + {Text: "dog", TextVec: []float32{1.0, 1.0, 0.0}}, + {Text: "elephant", TextVec: []float32{0.0, 1.0, 1.0}}, + {Text: "fox", TextVec: []float32{1.0, 0.0, 1.0}}, + {Text: "gorilla", TextVec: []float32{1.0, 1.0, 1.0}}, + } + + for _, doc := range documents { + _, _, err = modusdb.Create(db, doc, db1.ID()) + require.NoError(t, err) + } + + gids, docs, err := modusdb.Query[Document](db, modusdb.QueryParams{ + Filter: &modusdb.Filter{ + Field: "textVec", + Vector: modusdb.VectorPredicate{ + SimilarTo: []float32{0.1, 0.1, 0.1}, + TopK: 5, + }, + }, + }, db1.ID()) + + require.NoError(t, err) + require.Len(t, docs, 5) + require.Len(t, gids, 5) + require.Equal(t, "apple", docs[0].Text) + require.Equal(t, "dog", docs[1].Text) + require.Equal(t, "elephant", docs[2].Text) + require.Equal(t, "fox", docs[3].Text) + require.Equal(t, "gorilla", docs[4].Text) +} diff --git a/api_types.go b/api_types.go index 860edda..3e53bf8 100644 --- a/api_types.go +++ b/api_types.go @@ -4,10 +4,12 @@ import ( "context" "encoding/binary" "fmt" + "strings" "time" "github.com/dgraph-io/dgo/v240/protos/api" "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/types" "github.com/dgraph-io/dgraph/v24/x" "github.com/twpayne/go-geom" "github.com/twpayne/go-geom/encoding/wkb" @@ -26,6 +28,51 @@ type ConstrainedField struct { Value any } +type QueryParams struct { + Filter *Filter + Pagination *Pagination + Sorting *Sorting +} + +type Filter struct { + Field string + String StringPredicate + Vector VectorPredicate + And *Filter + Or *Filter + Not *Filter +} + +type Pagination struct { + Limit int64 + Offset int64 + After string +} + +type Sorting struct { + OrderAscField string + OrderDescField string + OrderDescFirst bool +} + +type StringPredicate struct { + Equals string + LessThan string + LessOrEqual string + GreaterThan string + GreaterOrEqual string + AllOfTerms []string + AnyOfTerms []string + AllOfText []string + AnyOfText []string + RegExp string +} + +type VectorPredicate struct { + SimilarTo []float32 + TopK int64 +} + type ModusDbOption func(*modusDbOptions) type modusDbOptions struct { @@ -110,6 +157,16 @@ func valueToApiVal(v any) (*api.Value, error) { return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil case float64: return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil + case []float32: + return &api.Value{Val: &api.Value_Vfloat32Val{ + Vfloat32Val: types.FloatArrayAsBytes(val)}}, nil + case []float64: + float32Slice := make([]float32, len(val)) + for i, v := range val { + float32Slice[i] = float32(v) + } + return &api.Value{Val: &api.Value_Vfloat32Val{ + Vfloat32Val: types.FloatArrayAsBytes(float32Slice)}}, nil case []byte: return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil case time.Time: @@ -130,3 +187,54 @@ func valueToApiVal(v any) (*api.Value, error) { return nil, fmt.Errorf("unsupported type %T", v) } } + +func filterToQueryFunc(typeName string, f Filter) QueryFunc { + // Handle logical operators first + if f.And != nil { + return And(filterToQueryFunc(typeName, *f.And)) + } + if f.Or != nil { + return Or(filterToQueryFunc(typeName, *f.Or)) + } + if f.Not != nil { + return Not(filterToQueryFunc(typeName, *f.Not)) + } + + // Handle field predicates + if f.String.Equals != "" { + return buildEqQuery(getPredicateName(typeName, f.Field), f.String.Equals) + } + if len(f.String.AllOfTerms) != 0 { + return buildAllOfTermsQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AllOfTerms, " ")) + } + if len(f.String.AnyOfTerms) != 0 { + return buildAnyOfTermsQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfTerms, " ")) + } + if len(f.String.AllOfText) != 0 { + return buildAllOfTextQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AllOfText, " ")) + } + if len(f.String.AnyOfText) != 0 { + return buildAnyOfTextQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfText, " ")) + } + if f.String.RegExp != "" { + return buildRegExpQuery(getPredicateName(typeName, f.Field), f.String.RegExp) + } + if f.String.LessThan != "" { + return buildLtQuery(getPredicateName(typeName, f.Field), f.String.LessThan) + } + if f.String.LessOrEqual != "" { + return buildLeQuery(getPredicateName(typeName, f.Field), f.String.LessOrEqual) + } + if f.String.GreaterThan != "" { + return buildGtQuery(getPredicateName(typeName, f.Field), f.String.GreaterThan) + } + if f.String.GreaterOrEqual != "" { + return buildGeQuery(getPredicateName(typeName, f.Field), f.String.GreaterOrEqual) + } + if f.Vector.SimilarTo != nil { + return buildSimilarToQuery(getPredicateName(typeName, f.Field), f.Vector.TopK, f.Vector.SimilarTo) + } + + // Return empty query if no conditions match + return func() string { return "" } +}