diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d92b85..e92097c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ - feat: add readfrom json tag to support reverse edges [#49](https://github.com/hypermodeinc/modusDB/pull/49) +- chore: Refactoring package management #51 [#51](https://github.com/hypermodeinc/modusDB/pull/51) + ## 2025-01-02 - Version 0.1.0 Baseline for the changelog. diff --git a/api.go b/api.go index 652ea19..beadf9f 100644 --- a/api.go +++ b/api.go @@ -14,6 +14,7 @@ import ( "github.com/dgraph-io/dgraph/v24/dql" "github.com/dgraph-io/dgraph/v24/schema" + "github.com/hypermodeinc/modusdb/api/utils" ) func Create[T any](db *DB, object T, ns ...uint64) (uint64, T, error) { @@ -34,7 +35,7 @@ func Create[T any](db *DB, object T, ns ...uint64) (uint64, T, error) { dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) if err != nil { return 0, object, err } @@ -66,14 +67,14 @@ func Upsert[T any](db *DB, object T, ns ...uint64) (uint64, T, bool, error) { return 0, object, false, err } - gid, cf, err := getUniqueConstraint[T](object) + gid, cf, err := GetUniqueConstraint[T](object) if err != nil { return 0, object, false, err } dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) if err != nil { return 0, object, false, err } @@ -83,19 +84,14 @@ func Upsert[T any](db *DB, object T, ns ...uint64) (uint64, T, bool, error) { return 0, object, false, err } - if gid != 0 { - gid, _, err = getByGidWithObject[T](ctx, n, gid, object) - if err != nil && err != ErrNoObjFound { - return 0, object, false, err - } - wasFound = err == nil - } else if cf != nil { - gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) - if err != nil && err != ErrNoObjFound { + if gid != 0 || cf != nil { + gid, err = getExistingObject[T](ctx, n, gid, cf, object) + if err != nil && err != utils.ErrNoObjFound { return 0, object, false, err } wasFound = err == nil } + if gid == 0 { gid, err = db.z.nextUID() if err != nil { @@ -104,7 +100,7 @@ func Upsert[T any](db *DB, object T, ns ...uint64) (uint64, T, bool, error) { } dms = make([]*dql.Mutation, 0) - err = generateCreateDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) if err != nil { return 0, object, false, err } diff --git a/api/mutations/mutations.go b/api/mutations/mutations.go new file mode 100644 index 0000000..4e4fedf --- /dev/null +++ b/api/mutations/mutations.go @@ -0,0 +1,71 @@ +package mutations + +import ( + "fmt" + "reflect" + "strings" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/schema" + "github.com/hypermodeinc/modusdb/api/utils" +) + +func HandleReverseEdge(jsonName string, value reflect.Type, nsId uint64, sch *schema.ParsedSchema, + jsonToReverseEdgeTags map[string]string) error { + if jsonToReverseEdgeTags[jsonName] == "" { + return nil + } + + if value.Kind() != reflect.Slice || value.Elem().Kind() != reflect.Struct { + return fmt.Errorf("reverse edge %s should be a slice of structs", jsonName) + } + + reverseEdge := jsonToReverseEdgeTags[jsonName] + typeName := strings.Split(reverseEdge, ".")[0] + u := &pb.SchemaUpdate{ + Predicate: utils.AddNamespace(nsId, reverseEdge), + ValueType: pb.Posting_UID, + Directive: pb.SchemaUpdate_REVERSE, + } + + sch.Preds = append(sch.Preds, u) + sch.Types = append(sch.Types, &pb.TypeUpdate{ + TypeName: utils.AddNamespace(nsId, typeName), + Fields: []*pb.SchemaUpdate{u}, + }) + return nil +} + +func CreateNQuadAndSchema(value any, gid uint64, jsonName string, t reflect.Type, + nsId uint64) (*api.NQuad, *pb.SchemaUpdate, error) { + valType, err := utils.ValueToPosting_ValType(value) + if err != nil { + return nil, nil, err + } + + val, err := utils.ValueToApiVal(value) + if err != nil { + return nil, nil, err + } + + nquad := &api.NQuad{ + Namespace: nsId, + Subject: fmt.Sprint(gid), + Predicate: utils.GetPredicateName(t.Name(), jsonName), + } + + u := &pb.SchemaUpdate{ + Predicate: utils.AddNamespace(nsId, utils.GetPredicateName(t.Name(), jsonName)), + ValueType: valType, + } + + if valType == pb.Posting_UID { + nquad.ObjectId = fmt.Sprint(value) + u.Directive = pb.SchemaUpdate_REVERSE + } else { + nquad.ObjectValue = val + } + + return nquad, u, nil +} diff --git a/api/query_gen/dql_query.go b/api/query_gen/dql_query.go new file mode 100644 index 0000000..5922342 --- /dev/null +++ b/api/query_gen/dql_query.go @@ -0,0 +1,173 @@ +package query_gen + +import ( + "fmt" + "strconv" + "strings" +) + +type QueryFunc func() string + +const ( + ObjQuery = ` + { + obj(func: %s) { + gid: uid + expand(_all_) { + gid: uid + expand(_all_) + dgraph.type + } + dgraph.type + %s + } + } + ` + + ObjsQuery = ` + { + objs(func: type("%s")%s) @filter(%s) { + gid: uid + expand(_all_) { + gid: uid + expand(_all_) + dgraph.type + } + dgraph.type + %s + } + } + ` + + ReverseEdgeQuery = ` + %s: ~%s { + gid: uid + expand(_all_) + dgraph.type + } + ` + + FuncUid = `uid(%d)` + FuncEq = `eq(%s, %s)` + FuncSimilarTo = `similar_to(%s, %d, "[%s]")` + FuncAllOfTerms = `allofterms(%s, "%s")` + FuncAnyOfTerms = `anyofterms(%s, "%s")` + FuncAllOfText = `alloftext(%s, "%s")` + FuncAnyOfText = `anyoftext(%s, "%s")` + FuncRegExp = `regexp(%s, /%s/)` + FuncLe = `le(%s, %s)` + FuncGe = `ge(%s, %s)` + FuncGt = `gt(%s, %s)` + FuncLt = `lt(%s, %s)` +) + +func BuildUidQuery(gid uint64) QueryFunc { + return func() string { + return fmt.Sprintf(FuncUid, gid) + } +} + +func BuildEqQuery(key string, value any) QueryFunc { + return func() string { + return fmt.Sprintf(FuncEq, key, value) + } +} + +func BuildSimilarToQuery(indexAttr string, topK int64, vec []float32) QueryFunc { + vecStrArr := make([]string, len(vec)) + for i := range vec { + vecStrArr[i] = strconv.FormatFloat(float64(vec[i]), 'f', -1, 32) + } + vecStr := strings.Join(vecStrArr, ",") + return func() string { + return fmt.Sprintf(FuncSimilarTo, indexAttr, topK, vecStr) + } +} + +func BuildAllOfTermsQuery(attr string, terms string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncAllOfTerms, attr, terms) + } +} + +func BuildAnyOfTermsQuery(attr string, terms string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncAnyOfTerms, attr, terms) + } +} + +func BuildAllOfTextQuery(attr, text string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncAllOfText, attr, text) + } +} + +func BuildAnyOfTextQuery(attr, text string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncAnyOfText, attr, text) + } +} + +func BuildRegExpQuery(attr, pattern string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncRegExp, attr, pattern) + } +} + +func BuildLeQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncLe, attr, value) + } +} + +func BuildGeQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncGe, attr, value) + } +} + +func BuildGtQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncGt, attr, value) + } +} + +func BuildLtQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncLt, attr, value) + } +} + +func And(qfs ...QueryFunc) QueryFunc { + return func() string { + qs := make([]string, len(qfs)) + for i, qf := range qfs { + qs[i] = qf() + } + return strings.Join(qs, " AND ") + } +} + +func Or(qfs ...QueryFunc) QueryFunc { + return func() string { + qs := make([]string, len(qfs)) + for i, qf := range qfs { + qs[i] = qf() + } + return strings.Join(qs, " OR ") + } +} + +func Not(qf QueryFunc) QueryFunc { + return func() string { + return "NOT " + qf() + } +} + +func FormatObjQuery(qf QueryFunc, extraFields string) string { + return fmt.Sprintf(ObjQuery, qf(), extraFields) +} + +func FormatObjsQuery(typeName string, qf QueryFunc, paginationAndSorting string, extraFields string) string { + return fmt.Sprintf(ObjsQuery, typeName, paginationAndSorting, qf(), extraFields) +} diff --git a/api/utils/dgraph.go b/api/utils/dgraph.go new file mode 100644 index 0000000..fc49050 --- /dev/null +++ b/api/utils/dgraph.go @@ -0,0 +1,147 @@ +package utils + +import ( + "encoding/binary" + "fmt" + "time" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/types" + "github.com/twpayne/go-geom" + "github.com/twpayne/go-geom/encoding/wkb" +) + +func addIndex(u *pb.SchemaUpdate, index string, uniqueConstraintExists bool) bool { + u.Directive = pb.SchemaUpdate_INDEX + switch index { + case "exact": + u.Tokenizer = []string{"exact"} + case "term": + u.Tokenizer = []string{"term"} + case "hash": + u.Tokenizer = []string{"hash"} + case "unique": + u.Tokenizer = []string{"exact"} + u.Unique = true + u.Upsert = true + uniqueConstraintExists = true + case "fulltext": + u.Tokenizer = []string{"fulltext"} + case "trigram": + u.Tokenizer = []string{"trigram"} + case "vector": + u.IndexSpecs = []*pb.VectorIndexSpec{ + { + Name: "hnsw", + Options: []*pb.OptionPair{ + { + Key: "metric", + Value: "cosine", + }, + }, + }, + } + default: + return uniqueConstraintExists + } + return uniqueConstraintExists +} + +func ValueToPosting_ValType(v any) (pb.Posting_ValType, error) { + switch v.(type) { + case string: + return pb.Posting_STRING, nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32: + return pb.Posting_INT, nil + case uint64: + return pb.Posting_UID, nil + case bool: + return pb.Posting_BOOL, nil + case float32, float64: + return pb.Posting_FLOAT, nil + case []byte: + return pb.Posting_BINARY, nil + case time.Time: + return pb.Posting_DATETIME, nil + case geom.Point: + return pb.Posting_GEO, nil + case []float32, []float64: + return pb.Posting_VFLOAT, nil + default: + return pb.Posting_DEFAULT, fmt.Errorf("unsupported type %T", v) + } +} + +func ValueToApiVal(v any) (*api.Value, error) { + switch val := v.(type) { + case string: + return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil + case int: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int8: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int16: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int32: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int64: + return &api.Value{Val: &api.Value_IntVal{IntVal: val}}, nil + case uint8: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint16: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint32: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint64: + return &api.Value{Val: &api.Value_UidVal{UidVal: val}}, nil + case bool: + return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}}, nil + case float32: + return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil + case float64: + return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil + case []float32: + return &api.Value{Val: &api.Value_Vfloat32Val{ + Vfloat32Val: types.FloatArrayAsBytes(val)}}, nil + case []float64: + float32Slice := make([]float32, len(val)) + for i, v := range val { + float32Slice[i] = float32(v) + } + return &api.Value{Val: &api.Value_Vfloat32Val{ + Vfloat32Val: types.FloatArrayAsBytes(float32Slice)}}, nil + case []byte: + return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil + case time.Time: + bytes, err := val.MarshalBinary() + if err != nil { + return nil, err + } + return &api.Value{Val: &api.Value_DateVal{DateVal: bytes}}, nil + case geom.Point: + bytes, err := wkb.Marshal(&val, binary.LittleEndian) + if err != nil { + return nil, err + } + return &api.Value{Val: &api.Value_GeoVal{GeoVal: bytes}}, nil + case uint: + return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}}, nil + default: + return nil, fmt.Errorf("unsupported type %T", v) + } +} + +func HandleConstraints(u *pb.SchemaUpdate, jsonToDbTags map[string]*DbTag, jsonName string, + valType pb.Posting_ValType, uniqueConstraintFound bool) (bool, error) { + if jsonToDbTags[jsonName] == nil { + return uniqueConstraintFound, nil + } + + constraint := jsonToDbTags[jsonName].Constraint + if constraint == "vector" && valType != pb.Posting_VFLOAT { + return false, fmt.Errorf("vector index can only be applied to []float values") + } + + return addIndex(u, constraint, uniqueConstraintFound), nil +} diff --git a/api/utils/reflect.go b/api/utils/reflect.go new file mode 100644 index 0000000..a4a57e0 --- /dev/null +++ b/api/utils/reflect.go @@ -0,0 +1,210 @@ +package utils + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +type DbTag struct { + Constraint string +} + +func GetFieldTags(t reflect.Type) (fieldToJsonTags map[string]string, + jsonToDbTags map[string]*DbTag, jsonToReverseEdgeTags map[string]string, err error) { + + fieldToJsonTags = make(map[string]string) + jsonToDbTags = make(map[string]*DbTag) + jsonToReverseEdgeTags = make(map[string]string) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag == "" { + return nil, nil, nil, fmt.Errorf("field %s has no json tag", field.Name) + } + jsonName := strings.Split(jsonTag, ",")[0] + fieldToJsonTags[field.Name] = jsonName + + reverseEdgeTag := field.Tag.Get("readFrom") + if reverseEdgeTag != "" { + typeAndField := strings.Split(reverseEdgeTag, ",") + if len(typeAndField) != 2 { + return nil, nil, nil, fmt.Errorf(`field %s has invalid readFrom tag, + expected format is type=,field=`, field.Name) + } + t := strings.Split(typeAndField[0], "=")[1] + f := strings.Split(typeAndField[1], "=")[1] + jsonToReverseEdgeTags[jsonName] = GetPredicateName(t, f) + } + + dbConstraintsTag := field.Tag.Get("db") + if dbConstraintsTag != "" { + jsonToDbTags[jsonName] = &DbTag{} + dbTagsSplit := strings.Split(dbConstraintsTag, ",") + for _, dbTag := range dbTagsSplit { + split := strings.Split(dbTag, "=") + if split[0] == "constraint" { + jsonToDbTags[jsonName].Constraint = split[1] + } + } + } + } + return fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, nil +} + +func GetJsonTagToValues(object any, fieldToJsonTags map[string]string) map[string]any { + values := make(map[string]any) + v := reflect.ValueOf(object) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + for fieldName, jsonName := range fieldToJsonTags { + fieldValue := v.FieldByName(fieldName) + values[jsonName] = fieldValue.Interface() + + } + return values +} + +func CreateDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, depth int) reflect.Type { + fields := make([]reflect.StructField, 0, len(fieldToJsonTags)) + for fieldName, jsonName := range fieldToJsonTags { + field, _ := t.FieldByName(fieldName) + if fieldName != "Gid" { + if field.Type.Kind() == reflect.Struct { + if depth <= 1 { + nestedFieldToJsonTags, _, _, _ := GetFieldTags(field.Type) + nestedType := CreateDynamicStruct(field.Type, nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: nestedType, + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } + } else if field.Type.Kind() == reflect.Ptr && + field.Type.Elem().Kind() == reflect.Struct { + nestedFieldToJsonTags, _, _, _ := GetFieldTags(field.Type.Elem()) + nestedType := CreateDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: reflect.PointerTo(nestedType), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } else if field.Type.Kind() == reflect.Slice && + field.Type.Elem().Kind() == reflect.Struct { + nestedFieldToJsonTags, _, _, _ := GetFieldTags(field.Type.Elem()) + nestedType := CreateDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: reflect.SliceOf(nestedType), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } else { + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: field.Type, + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } + + } + } + fields = append(fields, reflect.StructField{ + Name: "Gid", + Type: reflect.TypeOf(""), + Tag: reflect.StructTag(`json:"gid"`), + }, reflect.StructField{ + Name: "DgraphType", + Type: reflect.TypeOf([]string{}), + Tag: reflect.StructTag(`json:"dgraph.type"`), + }) + return reflect.StructOf(fields) +} + +func MapDynamicToFinal(dynamic any, final any, isNested bool) (uint64, error) { + vFinal := reflect.ValueOf(final).Elem() + vDynamic := reflect.ValueOf(dynamic).Elem() + + gid := uint64(0) + + for i := 0; i < vDynamic.NumField(); i++ { + + dynamicField := vDynamic.Type().Field(i) + dynamicFieldType := dynamicField.Type + dynamicValue := vDynamic.Field(i) + + var finalField reflect.Value + if dynamicField.Name == "Gid" { + finalField = vFinal.FieldByName("Gid") + gidStr := dynamicValue.String() + gid, _ = strconv.ParseUint(gidStr, 0, 64) + } else if dynamicField.Name == "DgraphType" { + fieldArrInterface := dynamicValue.Interface() + fieldArr, ok := fieldArrInterface.([]string) + if ok { + if len(fieldArr) == 0 { + if !isNested { + return 0, ErrNoObjFound + } else { + continue + } + } + } else { + return 0, fmt.Errorf("DgraphType field should be an array of strings") + } + } else { + finalField = vFinal.FieldByName(dynamicField.Name) + } + if dynamicFieldType.Kind() == reflect.Struct { + _, err := MapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface(), true) + if err != nil { + return 0, err + } + } else if dynamicFieldType.Kind() == reflect.Ptr && + dynamicFieldType.Elem().Kind() == reflect.Struct { + // if field is a pointer, find if the underlying is a struct + _, err := MapDynamicToFinal(dynamicValue.Interface(), finalField.Interface(), true) + if err != nil { + return 0, err + } + } else if dynamicFieldType.Kind() == reflect.Slice && + dynamicFieldType.Elem().Kind() == reflect.Struct { + for j := 0; j < dynamicValue.Len(); j++ { + sliceElem := dynamicValue.Index(j).Addr().Interface() + finalSliceElem := reflect.New(finalField.Type().Elem()).Elem() + _, err := MapDynamicToFinal(sliceElem, finalSliceElem.Addr().Interface(), true) + if err != nil { + return 0, err + } + finalField.Set(reflect.Append(finalField, finalSliceElem)) + } + } else { + if finalField.IsValid() && finalField.CanSet() { + // if field name is gid, convert it to uint64 + if dynamicField.Name == "Gid" { + finalField.SetUint(gid) + } else { + finalField.Set(dynamicValue) + } + } + } + } + return gid, nil +} + +func ConvertDynamicToTyped[T any](obj any, t reflect.Type) (uint64, T, error) { + var result T + finalObject := reflect.New(t).Interface() + gid, err := MapDynamicToFinal(obj, finalObject, false) + if err != nil { + return 0, result, err + } + + if typedPtr, ok := finalObject.(*T); ok { + return gid, *typedPtr, nil + } else if dirType, ok := finalObject.(T); ok { + return gid, dirType, nil + } + return 0, result, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) +} diff --git a/utils.go b/api/utils/utils.go similarity index 62% rename from utils.go rename to api/utils/utils.go index e571bb2..ba726ed 100644 --- a/utils.go +++ b/api/utils/utils.go @@ -7,7 +7,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package modusdb +package utils import ( "fmt" @@ -15,10 +15,15 @@ import ( "github.com/dgraph-io/dgraph/v24/x" ) -func getPredicateName(typeName, fieldName string) string { +var ( + ErrNoObjFound = fmt.Errorf("no object found") + NoUniqueConstr = "unique constraint not defined for any field on type %s" +) + +func GetPredicateName(typeName, fieldName string) string { return fmt.Sprint(typeName, ".", fieldName) } -func addNamespace(ns uint64, pred string) string { +func AddNamespace(ns uint64, pred string) string { return x.NamespaceAttr(ns, pred) } diff --git a/api_dql.go b/api_dql.go deleted file mode 100644 index 6c63171..0000000 --- a/api_dql.go +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Copyright 2025 Hypermode Inc. - * Licensed under the terms of the Apache License, Version 2.0 - * See the LICENSE file that accompanied this code for further details. - * - * SPDX-FileCopyrightText: 2025 Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package modusdb - -import ( - "fmt" - "strconv" - "strings" -) - -type QueryFunc func() string - -const ( - objQuery = ` - { - obj(func: %s) { - gid: uid - expand(_all_) { - gid: uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - objsQuery = ` - { - objs(func: type("%s")%s) @filter(%s) { - gid: uid - expand(_all_) { - gid: uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - reverseEdgeQuery = ` - %s: ~%s { - gid: uid - expand(_all_) - dgraph.type - } - ` - - funcUid = `uid(%d)` - funcEq = `eq(%s, %s)` - funcSimilarTo = `similar_to(%s, %d, "[%s]")` - funcAllOfTerms = `allofterms(%s, "%s")` - funcAnyOfTerms = `anyofterms(%s, "%s")` - funcAllOfText = `alloftext(%s, "%s")` - funcAnyOfText = `anyoftext(%s, "%s")` - funcRegExp = `regexp(%s, /%s/)` - funcLe = `le(%s, %s)` - funcGe = `ge(%s, %s)` - funcGt = `gt(%s, %s)` - funcLt = `lt(%s, %s)` -) - -func buildUidQuery(gid uint64) QueryFunc { - return func() string { - return fmt.Sprintf(funcUid, gid) - } -} - -func buildEqQuery(key string, value any) QueryFunc { - return func() string { - return fmt.Sprintf(funcEq, key, value) - } -} - -func buildSimilarToQuery(indexAttr string, topK int64, vec []float32) QueryFunc { - vecStrArr := make([]string, len(vec)) - for i := range vec { - vecStrArr[i] = strconv.FormatFloat(float64(vec[i]), 'f', -1, 32) - } - vecStr := strings.Join(vecStrArr, ",") - return func() string { - return fmt.Sprintf(funcSimilarTo, indexAttr, topK, vecStr) - } -} - -func buildAllOfTermsQuery(attr string, terms string) QueryFunc { - return func() string { - return fmt.Sprintf(funcAllOfTerms, attr, terms) - } -} - -func buildAnyOfTermsQuery(attr string, terms string) QueryFunc { - return func() string { - return fmt.Sprintf(funcAnyOfTerms, attr, terms) - } -} - -func buildAllOfTextQuery(attr, text string) QueryFunc { - return func() string { - return fmt.Sprintf(funcAllOfText, attr, text) - } -} - -func buildAnyOfTextQuery(attr, text string) QueryFunc { - return func() string { - return fmt.Sprintf(funcAnyOfText, attr, text) - } -} - -func buildRegExpQuery(attr, pattern string) QueryFunc { - return func() string { - return fmt.Sprintf(funcRegExp, attr, pattern) - } -} - -func buildLeQuery(attr, value string) QueryFunc { - return func() string { - return fmt.Sprintf(funcLe, attr, value) - } -} - -func buildGeQuery(attr, value string) QueryFunc { - return func() string { - return fmt.Sprintf(funcGe, attr, value) - } -} - -func buildGtQuery(attr, value string) QueryFunc { - return func() string { - return fmt.Sprintf(funcGt, attr, value) - } -} - -func buildLtQuery(attr, value string) QueryFunc { - return func() string { - return fmt.Sprintf(funcLt, attr, value) - } -} - -func And(qfs ...QueryFunc) QueryFunc { - return func() string { - qs := make([]string, len(qfs)) - for i, qf := range qfs { - qs[i] = qf() - } - return strings.Join(qs, " AND ") - } -} - -func Or(qfs ...QueryFunc) QueryFunc { - return func() string { - qs := make([]string, len(qfs)) - for i, qf := range qfs { - qs[i] = qf() - } - return strings.Join(qs, " OR ") - } -} - -func Not(qf QueryFunc) QueryFunc { - return func() string { - return "NOT " + qf() - } -} - -func formatObjQuery(qf QueryFunc, extraFields string) string { - return fmt.Sprintf(objQuery, qf(), extraFields) -} - -func formatObjsQuery(typeName string, qf QueryFunc, paginationAndSorting string, extraFields string) string { - return fmt.Sprintf(objsQuery, typeName, paginationAndSorting, qf(), extraFields) -} - -// Helper function to combine multiple filters -func filtersToQueryFunc(typeName string, filter Filter) QueryFunc { - return filterToQueryFunc(typeName, filter) -} - -func paginationToQueryString(p Pagination) string { - paginationStr := "" - if p.Limit > 0 { - paginationStr += ", " + fmt.Sprintf("first: %d", p.Limit) - } - if p.Offset > 0 { - paginationStr += ", " + fmt.Sprintf("offset: %d", p.Offset) - } else if p.After != "" { - paginationStr += ", " + fmt.Sprintf("after: %s", p.After) - } - if paginationStr == "" { - return "" - } - return paginationStr -} - -func sortingToQueryString(typeName string, s Sorting) string { - if s.OrderAscField == "" && s.OrderDescField == "" { - return "" - } - - var parts []string - first, second := s.OrderDescField, s.OrderAscField - firstOp, secondOp := "orderdesc", "orderasc" - - if !s.OrderDescFirst { - first, second = s.OrderAscField, s.OrderDescField - firstOp, secondOp = "orderasc", "orderdesc" - } - - if first != "" { - parts = append(parts, fmt.Sprintf("%s: %s", firstOp, getPredicateName(typeName, first))) - } - if second != "" { - parts = append(parts, fmt.Sprintf("%s: %s", secondOp, getPredicateName(typeName, second))) - } - - return ", " + strings.Join(parts, ", ") -} diff --git a/api_mutate_helper.go b/api_mutate_helper.go deleted file mode 100644 index f90c258..0000000 --- a/api_mutate_helper.go +++ /dev/null @@ -1,299 +0,0 @@ -/* - * Copyright 2025 Hypermode Inc. - * Licensed under the terms of the Apache License, Version 2.0 - * See the LICENSE file that accompanied this code for further details. - * - * SPDX-FileCopyrightText: 2025 Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package modusdb - -import ( - "context" - "fmt" - "reflect" - "strings" - - "github.com/dgraph-io/dgo/v240/protos/api" - "github.com/dgraph-io/dgraph/v24/dql" - "github.com/dgraph-io/dgraph/v24/protos/pb" - "github.com/dgraph-io/dgraph/v24/query" - "github.com/dgraph-io/dgraph/v24/schema" - "github.com/dgraph-io/dgraph/v24/worker" - "github.com/dgraph-io/dgraph/v24/x" -) - -func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object T, - gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { - t := reflect.TypeOf(object) - if t.Kind() != reflect.Struct { - return fmt.Errorf("expected struct, got %s", t.Kind()) - } - - fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, err := getFieldTags(t) - if err != nil { - return err - } - jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) - - nquads := make([]*api.NQuad, 0) - uniqueConstraintFound := false - for jsonName, value := range jsonTagToValue { - var val *api.Value - var valType pb.Posting_ValType - - reflectValueType := reflect.TypeOf(value) - var nquad *api.NQuad - - if jsonToReverseEdgeTags[jsonName] != "" { - if reflectValueType.Kind() != reflect.Slice || reflectValueType.Elem().Kind() != reflect.Struct { - return fmt.Errorf("reverse edge %s should be a slice of structs", jsonName) - } - reverseEdge := jsonToReverseEdgeTags[jsonName] - typeName := strings.Split(reverseEdge, ".")[0] - u := &pb.SchemaUpdate{ - Predicate: addNamespace(n.id, reverseEdge), - ValueType: pb.Posting_UID, - Directive: pb.SchemaUpdate_REVERSE, - } - sch.Preds = append(sch.Preds, u) - sch.Types = append(sch.Types, &pb.TypeUpdate{ - TypeName: addNamespace(n.id, typeName), - Fields: []*pb.SchemaUpdate{u}, - }) - continue - } - if jsonName == "gid" { - uniqueConstraintFound = true - continue - } - - if reflectValueType.Kind() == reflect.Struct { - value = reflect.ValueOf(value).Interface() - newGid, err := getUidOrMutate(ctx, n.db, n, value) - if err != nil { - return err - } - value = newGid - } else if reflectValueType.Kind() == reflect.Pointer { - // dereference the pointer - reflectValueType = reflectValueType.Elem() - if reflectValueType.Kind() == reflect.Struct { - // convert value to pointer, and then dereference - value = reflect.ValueOf(value).Elem().Interface() - newGid, err := getUidOrMutate(ctx, n.db, n, value) - if err != nil { - return err - } - value = newGid - } - } - valType, err = valueToPosting_ValType(value) - if err != nil { - return err - } - val, err = valueToApiVal(value) - if err != nil { - return err - } - - nquad = &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: getPredicateName(t.Name(), jsonName), - } - - u := &pb.SchemaUpdate{ - Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), - ValueType: valType, - } - - if valType == pb.Posting_UID { - nquad.ObjectId = fmt.Sprint(value) - u.Directive = pb.SchemaUpdate_REVERSE - } else { - nquad.ObjectValue = val - } - - if jsonToDbTags[jsonName] != nil { - constraint := jsonToDbTags[jsonName].constraint - if constraint == "vector" && valType != pb.Posting_VFLOAT { - return fmt.Errorf("vector index can only be applied to []float values") - } - uniqueConstraintFound = addIndex(u, constraint, uniqueConstraintFound) - } - - sch.Preds = append(sch.Preds, u) - nquads = append(nquads, nquad) - } - if !uniqueConstraintFound { - return fmt.Errorf(NoUniqueConstr, t.Name()) - } - sch.Types = append(sch.Types, &pb.TypeUpdate{ - TypeName: addNamespace(n.id, t.Name()), - Fields: sch.Preds, - }) - - val, err := valueToApiVal(t.Name()) - if err != nil { - return err - } - typeNquad := &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: "dgraph.type", - ObjectValue: val, - } - nquads = append(nquads, typeNquad) - - *dms = append(*dms, &dql.Mutation{ - Set: nquads, - }) - - return nil -} - -func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { - return []*dql.Mutation{{ - Del: []*api.NQuad{ - { - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: x.Star, - ObjectValue: &api.Value{ - Val: &api.Value_DefaultVal{DefaultVal: x.Star}, - }, - }, - }, - }} -} - -func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { - edges, err := query.ToDirectedEdges(dms, nil) - if err != nil { - return err - } - - if !db.isOpen { - return ErrClosedDB - } - - startTs, err := db.z.nextTs() - if err != nil { - return err - } - commitTs, err := db.z.nextTs() - if err != nil { - return err - } - - m := &pb.Mutations{ - GroupId: 1, - StartTs: startTs, - Edges: edges, - } - m.Edges, err = query.ExpandEdges(ctx, m) - if err != nil { - return fmt.Errorf("error expanding edges: %w", err) - } - - p := &pb.Proposal{Mutations: m, StartTs: startTs} - if err := worker.ApplyMutations(ctx, p); err != nil { - return err - } - - return worker.ApplyCommited(ctx, &pb.OracleDelta{ - Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, - }) -} - -func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { - gid, cf, err := getUniqueConstraint[T](object) - if err != nil { - return 0, err - } - - dms := make([]*dql.Mutation, 0) - sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) - if err != nil { - return 0, err - } - - err = n.alterSchemaWithParsed(ctx, sch) - if err != nil { - return 0, err - } - if gid != 0 { - gid, _, err = getByGidWithObject[T](ctx, n, gid, object) - if err != nil && err != ErrNoObjFound { - return 0, err - } - if err == nil { - return gid, nil - } - } else if cf != nil { - gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) - if err != nil && err != ErrNoObjFound { - return 0, err - } - if err == nil { - return gid, nil - } - } - - gid, err = db.z.nextUID() - if err != nil { - return 0, err - } - - dms = make([]*dql.Mutation, 0) - err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) - if err != nil { - return 0, err - } - - err = applyDqlMutations(ctx, db, dms) - if err != nil { - return 0, err - } - - return gid, nil -} - -func addIndex(u *pb.SchemaUpdate, index string, uniqueConstraintExists bool) bool { - u.Directive = pb.SchemaUpdate_INDEX - switch index { - case "exact": - u.Tokenizer = []string{"exact"} - case "term": - u.Tokenizer = []string{"term"} - case "hash": - u.Tokenizer = []string{"hash"} - case "unique": - u.Tokenizer = []string{"exact"} - u.Unique = true - u.Upsert = true - uniqueConstraintExists = true - case "fulltext": - u.Tokenizer = []string{"fulltext"} - case "trigram": - u.Tokenizer = []string{"trigram"} - case "vector": - u.IndexSpecs = []*pb.VectorIndexSpec{ - { - Name: "hnsw", - Options: []*pb.OptionPair{ - { - Key: "metric", - Value: "cosine", - }, - }, - }, - } - default: - return uniqueConstraintExists - } - return uniqueConstraintExists -} diff --git a/api_mutation_gen.go b/api_mutation_gen.go new file mode 100644 index 0000000..3bfdc6f --- /dev/null +++ b/api_mutation_gen.go @@ -0,0 +1,121 @@ +/* + * Copyright 2025 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2025 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusdb + +import ( + "context" + "fmt" + "reflect" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/dql" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/schema" + "github.com/dgraph-io/dgraph/v24/x" + "github.com/hypermodeinc/modusdb/api/mutations" + "github.com/hypermodeinc/modusdb/api/utils" +) + +func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object T, + gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { + t := reflect.TypeOf(object) + if t.Kind() != reflect.Struct { + return fmt.Errorf("expected struct, got %s", t.Kind()) + } + + fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, err := utils.GetFieldTags(t) + if err != nil { + return err + } + jsonTagToValue := utils.GetJsonTagToValues(object, fieldToJsonTags) + + nquads := make([]*api.NQuad, 0) + uniqueConstraintFound := false + for jsonName, value := range jsonTagToValue { + + reflectValueType := reflect.TypeOf(value) + var nquad *api.NQuad + + if jsonToReverseEdgeTags[jsonName] != "" { + if err := mutations.HandleReverseEdge(jsonName, reflectValueType, n.id, sch, jsonToReverseEdgeTags); err != nil { + return err + } + continue + } + if jsonName == "gid" { + uniqueConstraintFound = true + continue + } + + value, err = processStructValue(ctx, value, n) + if err != nil { + return err + } + + value, err = processPointerValue(ctx, value, n) + if err != nil { + return err + } + + nquad, u, err := mutations.CreateNQuadAndSchema(value, gid, jsonName, t, n.ID()) + if err != nil { + return err + } + + uniqueConstraintFound, err = utils.HandleConstraints(u, jsonToDbTags, jsonName, u.ValueType, uniqueConstraintFound) + if err != nil { + return err + } + + sch.Preds = append(sch.Preds, u) + nquads = append(nquads, nquad) + } + if !uniqueConstraintFound { + return fmt.Errorf(utils.NoUniqueConstr, t.Name()) + } + + sch.Types = append(sch.Types, &pb.TypeUpdate{ + TypeName: utils.AddNamespace(n.id, t.Name()), + Fields: sch.Preds, + }) + + val, err := utils.ValueToApiVal(t.Name()) + if err != nil { + return err + } + typeNquad := &api.NQuad{ + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: "dgraph.type", + ObjectValue: val, + } + nquads = append(nquads, typeNquad) + + *dms = append(*dms, &dql.Mutation{ + Set: nquads, + }) + + return nil +} + +func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { + return []*dql.Mutation{{ + Del: []*api.NQuad{ + { + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: x.Star, + ObjectValue: &api.Value{ + Val: &api.Value_DefaultVal{DefaultVal: x.Star}, + }, + }, + }, + }} +} diff --git a/api_mutation_helpers.go b/api_mutation_helpers.go new file mode 100644 index 0000000..d4b0cf4 --- /dev/null +++ b/api_mutation_helpers.go @@ -0,0 +1,123 @@ +package modusdb + +import ( + "context" + "fmt" + "reflect" + + "github.com/dgraph-io/dgraph/v24/dql" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/query" + "github.com/dgraph-io/dgraph/v24/schema" + "github.com/dgraph-io/dgraph/v24/worker" + "github.com/hypermodeinc/modusdb/api/utils" +) + +func processStructValue(ctx context.Context, value any, n *Namespace) (any, error) { + if reflect.TypeOf(value).Kind() == reflect.Struct { + value = reflect.ValueOf(value).Interface() + newGid, err := getUidOrMutate(ctx, n.db, n, value) + if err != nil { + return nil, err + } + return newGid, nil + } + return value, nil +} + +func processPointerValue(ctx context.Context, value any, n *Namespace) (any, error) { + reflectValueType := reflect.TypeOf(value) + if reflectValueType.Kind() == reflect.Pointer { + reflectValueType = reflectValueType.Elem() + if reflectValueType.Kind() == reflect.Struct { + value = reflect.ValueOf(value).Elem().Interface() + return processStructValue(ctx, value, n) + } + } + return value, nil +} + +func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { + gid, cf, err := GetUniqueConstraint[T](object) + if err != nil { + return 0, err + } + + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateSetDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, err + } + + err = n.alterSchemaWithParsed(ctx, sch) + if err != nil { + return 0, err + } + if gid != 0 || cf != nil { + gid, err = getExistingObject(ctx, n, gid, cf, object) + if err != nil && err != utils.ErrNoObjFound { + return 0, err + } + if err == nil { + return gid, nil + } + } + + gid, err = db.z.nextUID() + if err != nil { + return 0, err + } + + dms = make([]*dql.Mutation, 0) + err = generateSetDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, err + } + + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, err + } + + return gid, nil +} + +func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { + edges, err := query.ToDirectedEdges(dms, nil) + if err != nil { + return err + } + + if !db.isOpen { + return ErrClosedDB + } + + startTs, err := db.z.nextTs() + if err != nil { + return err + } + commitTs, err := db.z.nextTs() + if err != nil { + return err + } + + m := &pb.Mutations{ + GroupId: 1, + StartTs: startTs, + Edges: edges, + } + m.Edges, err = query.ExpandEdges(ctx, m) + if err != nil { + return fmt.Errorf("error expanding edges: %w", err) + } + + p := &pb.Proposal{Mutations: m, StartTs: startTs} + if err := worker.ApplyMutations(ctx, p); err != nil { + return err + } + + return worker.ApplyCommited(ctx, &pb.OracleDelta{ + Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, + }) +} diff --git a/api_query_helper.go b/api_query_execution.go similarity index 69% rename from api_query_helper.go rename to api_query_execution.go index 9ec211c..bf9cb23 100644 --- a/api_query_helper.go +++ b/api_query_execution.go @@ -14,6 +14,9 @@ import ( "encoding/json" "fmt" "reflect" + + "github.com/hypermodeinc/modusdb/api/query_gen" + "github.com/hypermodeinc/modusdb/api/utils" ) func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, T, error) { @@ -47,14 +50,15 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac obj T, withReverse bool, args ...R) (uint64, T, error) { t := reflect.TypeOf(obj) - fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := utils.GetFieldTags(t) if err != nil { return 0, obj, err } readFromQuery := "" if withReverse { for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(reverseEdgeQuery, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + readFromQuery += fmt.Sprintf(query_gen.ReverseEdgeQuery, + utils.GetPredicateName(t.Name(), jsonTag), reverseEdgeTag) } } @@ -62,14 +66,15 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac var query string gid, ok := any(args[0]).(uint64) if ok { - query = formatObjQuery(buildUidQuery(gid), readFromQuery) + query = query_gen.FormatObjQuery(query_gen.BuildUidQuery(gid), readFromQuery) } else if cf, ok = any(args[0]).(ConstrainedField); ok { - query = formatObjQuery(buildEqQuery(getPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) + query = query_gen.FormatObjQuery(query_gen.BuildEqQuery(utils.GetPredicateName(t.Name(), + cf.Key), cf.Value), readFromQuery) } else { return 0, obj, fmt.Errorf("invalid unique field type") } - if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { + if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].Constraint == "" { return 0, obj, fmt.Errorf("constraint not defined for field %s", cf.Key) } @@ -78,7 +83,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac return 0, obj, err } - dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) + dynamicType := utils.CreateDynamicStruct(t, fieldToJsonTags, 1) dynamicInstance := reflect.New(dynamicType).Interface() @@ -95,37 +100,22 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac // Check if we have at least one object in the response if len(result.Obj) == 0 { - return 0, obj, ErrNoObjFound - } - - // Map the dynamic struct to the final type T - finalObject := reflect.New(t).Interface() - gid, err = mapDynamicToFinal(result.Obj[0], finalObject, false) - if err != nil { - return 0, obj, err - } - - if typedPtr, ok := finalObject.(*T); ok { - return gid, *typedPtr, nil - } - - if dirType, ok := finalObject.(T); ok { - return gid, dirType, nil + return 0, obj, utils.ErrNoObjFound } - return 0, obj, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) + return utils.ConvertDynamicToTyped[T](result.Obj[0], t) } func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryParams, withReverse bool) ([]uint64, []T, error) { var obj T t := reflect.TypeOf(obj) - fieldToJsonTags, _, jsonToReverseEdgeTags, err := getFieldTags(t) + fieldToJsonTags, _, jsonToReverseEdgeTags, err := utils.GetFieldTags(t) if err != nil { return nil, nil, err } - var filterQueryFunc QueryFunc = func() string { + var filterQueryFunc query_gen.QueryFunc = func() string { return "" } var paginationAndSorting string @@ -146,18 +136,18 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar readFromQuery := "" if withReverse { for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(reverseEdgeQuery, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + readFromQuery += fmt.Sprintf(query_gen.ReverseEdgeQuery, utils.GetPredicateName(t.Name(), jsonTag), reverseEdgeTag) } } - query := formatObjsQuery(t.Name(), filterQueryFunc, paginationAndSorting, readFromQuery) + query := query_gen.FormatObjsQuery(t.Name(), filterQueryFunc, paginationAndSorting, readFromQuery) resp, err := n.queryWithLock(ctx, query) if err != nil { return nil, nil, err } - dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) + dynamicType := utils.CreateDynamicStruct(t, fieldToJsonTags, 1) var result struct { Objs []any `json:"objs"` @@ -181,25 +171,30 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar return nil, nil, err } - var gids []uint64 - var objs []T - for _, obj := range result.Objs { - finalObject := reflect.New(t).Interface() - gid, err := mapDynamicToFinal(obj, finalObject, false) + gids := make([]uint64, len(result.Objs)) + objs := make([]T, len(result.Objs)) + for i, obj := range result.Objs { + gid, typedObj, err := utils.ConvertDynamicToTyped[T](obj, t) if err != nil { return nil, nil, err } - - if typedPtr, ok := finalObject.(*T); ok { - gids = append(gids, gid) - objs = append(objs, *typedPtr) - } else if dirType, ok := finalObject.(T); ok { - gids = append(gids, gid) - objs = append(objs, dirType) - } else { - return nil, nil, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) - } + gids[i] = gid + objs[i] = typedObj } return gids, objs, nil } + +func getExistingObject[T any](ctx context.Context, n *Namespace, gid uint64, cf *ConstrainedField, + object T) (uint64, error) { + var err error + if gid != 0 { + gid, _, err = getByGidWithObject[T](ctx, n, gid, object) + } else if cf != nil { + gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) + } + if err != nil { + return 0, err + } + return gid, nil +} diff --git a/api_reflect.go b/api_reflect.go index 5a32b01..029a223 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -12,203 +12,17 @@ package modusdb import ( "fmt" "reflect" - "strconv" - "strings" -) - -type dbTag struct { - constraint string -} - -func getFieldTags(t reflect.Type) (fieldToJsonTags map[string]string, - jsonToDbTags map[string]*dbTag, jsonToReverseEdgeTags map[string]string, err error) { - - fieldToJsonTags = make(map[string]string) - jsonToDbTags = make(map[string]*dbTag) - jsonToReverseEdgeTags = make(map[string]string) - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - jsonTag := field.Tag.Get("json") - if jsonTag == "" { - return nil, nil, nil, fmt.Errorf("field %s has no json tag", field.Name) - } - jsonName := strings.Split(jsonTag, ",")[0] - fieldToJsonTags[field.Name] = jsonName - - reverseEdgeTag := field.Tag.Get("readFrom") - if reverseEdgeTag != "" { - typeAndField := strings.Split(reverseEdgeTag, ",") - if len(typeAndField) != 2 { - return nil, nil, nil, fmt.Errorf(`field %s has invalid readFrom tag, - expected format is type=,field=`, field.Name) - } - t := strings.Split(typeAndField[0], "=")[1] - f := strings.Split(typeAndField[1], "=")[1] - jsonToReverseEdgeTags[jsonName] = getPredicateName(t, f) - } - - dbConstraintsTag := field.Tag.Get("db") - if dbConstraintsTag != "" { - jsonToDbTags[jsonName] = &dbTag{} - dbTagsSplit := strings.Split(dbConstraintsTag, ",") - for _, dbTag := range dbTagsSplit { - split := strings.Split(dbTag, "=") - if split[0] == "constraint" { - jsonToDbTags[jsonName].constraint = split[1] - } - } - } - } - return fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, nil -} - -func getJsonTagToValues(object any, fieldToJsonTags map[string]string) map[string]any { - values := make(map[string]any) - v := reflect.ValueOf(object) - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - for fieldName, jsonName := range fieldToJsonTags { - fieldValue := v.FieldByName(fieldName) - values[jsonName] = fieldValue.Interface() - } - return values -} - -func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, depth int) reflect.Type { - fields := make([]reflect.StructField, 0, len(fieldToJsonTags)) - for fieldName, jsonName := range fieldToJsonTags { - field, _ := t.FieldByName(fieldName) - if fieldName != "Gid" { - if field.Type.Kind() == reflect.Struct { - if depth <= 1 { - nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type) - nestedType := createDynamicStruct(field.Type, nestedFieldToJsonTags, depth+1) - fields = append(fields, reflect.StructField{ - Name: field.Name, - Type: nestedType, - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), - }) - } - } else if field.Type.Kind() == reflect.Ptr && - field.Type.Elem().Kind() == reflect.Struct { - nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type.Elem()) - nestedType := createDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) - fields = append(fields, reflect.StructField{ - Name: field.Name, - Type: reflect.PointerTo(nestedType), - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), - }) - } else if field.Type.Kind() == reflect.Slice && - field.Type.Elem().Kind() == reflect.Struct { - nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type.Elem()) - nestedType := createDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) - fields = append(fields, reflect.StructField{ - Name: field.Name, - Type: reflect.SliceOf(nestedType), - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), - }) - } else { - fields = append(fields, reflect.StructField{ - Name: field.Name, - Type: field.Type, - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), - }) - } - - } - } - fields = append(fields, reflect.StructField{ - Name: "Gid", - Type: reflect.TypeOf(""), - Tag: reflect.StructTag(`json:"gid"`), - }, reflect.StructField{ - Name: "DgraphType", - Type: reflect.TypeOf([]string{}), - Tag: reflect.StructTag(`json:"dgraph.type"`), - }) - return reflect.StructOf(fields) -} - -func mapDynamicToFinal(dynamic any, final any, isNested bool) (uint64, error) { - vFinal := reflect.ValueOf(final).Elem() - vDynamic := reflect.ValueOf(dynamic).Elem() - - gid := uint64(0) - - for i := 0; i < vDynamic.NumField(); i++ { - - dynamicField := vDynamic.Type().Field(i) - dynamicFieldType := dynamicField.Type - dynamicValue := vDynamic.Field(i) - - var finalField reflect.Value - if dynamicField.Name == "Gid" { - finalField = vFinal.FieldByName("Gid") - gidStr := dynamicValue.String() - gid, _ = strconv.ParseUint(gidStr, 0, 64) - } else if dynamicField.Name == "DgraphType" { - fieldArrInterface := dynamicValue.Interface() - fieldArr, ok := fieldArrInterface.([]string) - if ok { - if len(fieldArr) == 0 { - if !isNested { - return 0, ErrNoObjFound - } else { - continue - } - } - } else { - return 0, fmt.Errorf("DgraphType field should be an array of strings") - } - } else { - finalField = vFinal.FieldByName(dynamicField.Name) - } - if dynamicFieldType.Kind() == reflect.Struct { - _, err := mapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface(), true) - if err != nil { - return 0, err - } - } else if dynamicFieldType.Kind() == reflect.Ptr && - dynamicFieldType.Elem().Kind() == reflect.Struct { - // if field is a pointer, find if the underlying is a struct - _, err := mapDynamicToFinal(dynamicValue.Interface(), finalField.Interface(), true) - if err != nil { - return 0, err - } - } else if dynamicFieldType.Kind() == reflect.Slice && - dynamicFieldType.Elem().Kind() == reflect.Struct { - for j := 0; j < dynamicValue.Len(); j++ { - sliceElem := dynamicValue.Index(j).Addr().Interface() - finalSliceElem := reflect.New(finalField.Type().Elem()).Elem() - _, err := mapDynamicToFinal(sliceElem, finalSliceElem.Addr().Interface(), true) - if err != nil { - return 0, err - } - finalField.Set(reflect.Append(finalField, finalSliceElem)) - } - } else { - if finalField.IsValid() && finalField.CanSet() { - // if field name is gid, convert it to uint64 - if dynamicField.Name == "Gid" { - finalField.SetUint(gid) - } else { - finalField.Set(dynamicValue) - } - } - } - } - return gid, nil -} + "github.com/hypermodeinc/modusdb/api/utils" +) -func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { +func GetUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { t := reflect.TypeOf(object) - fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTags, _, err := utils.GetFieldTags(t) if err != nil { return 0, nil, err } - jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) + jsonTagToValue := utils.GetJsonTagToValues(object, fieldToJsonTags) for jsonName, value := range jsonTagToValue { if jsonName == "gid" { @@ -220,7 +34,7 @@ func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { return gid, nil, nil } } - if jsonToDbTags[jsonName] != nil && isValidUniqueIndex(jsonToDbTags[jsonName].constraint) { + if jsonToDbTags[jsonName] != nil && IsValidUniqueIndex(jsonToDbTags[jsonName].Constraint) { // check if value is zero or nil if value == reflect.Zero(reflect.TypeOf(value)).Interface() || value == nil { continue @@ -232,9 +46,9 @@ func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { } } - return 0, nil, fmt.Errorf(NoUniqueConstr, t.Name()) + return 0, nil, fmt.Errorf(utils.NoUniqueConstr, t.Name()) } -func isValidUniqueIndex(name string) bool { +func IsValidUniqueIndex(name string) bool { return name == "unique" } diff --git a/api_test.go b/api_test.go index f7f15e0..cb0ca7a 100644 --- a/api_test.go +++ b/api_test.go @@ -17,6 +17,7 @@ import ( "github.com/stretchr/testify/require" "github.com/hypermodeinc/modusdb" + "github.com/hypermodeinc/modusdb/api/utils" ) type User struct { @@ -767,7 +768,7 @@ func TestNestedObjectMutationWithBadType(t *testing.T) { _, _, err = modusdb.Create(db, branch, db1.ID()) require.Error(t, err) - require.Equal(t, fmt.Sprintf(modusdb.NoUniqueConstr, "BadProject"), err.Error()) + require.Equal(t, fmt.Sprintf(utils.NoUniqueConstr, "BadProject"), err.Error()) proj := BadProject{ Name: "P", @@ -776,7 +777,7 @@ func TestNestedObjectMutationWithBadType(t *testing.T) { _, _, err = modusdb.Create(db, proj, db1.ID()) require.Error(t, err) - require.Equal(t, fmt.Sprintf(modusdb.NoUniqueConstr, "BadProject"), err.Error()) + require.Equal(t, fmt.Sprintf(utils.NoUniqueConstr, "BadProject"), err.Error()) } diff --git a/api_types.go b/api_types.go index 2384f35..90d0af1 100644 --- a/api_types.go +++ b/api_types.go @@ -11,22 +11,12 @@ package modusdb import ( "context" - "encoding/binary" "fmt" "strings" - "time" - "github.com/dgraph-io/dgo/v240/protos/api" - "github.com/dgraph-io/dgraph/v24/protos/pb" - "github.com/dgraph-io/dgraph/v24/types" "github.com/dgraph-io/dgraph/v24/x" - "github.com/twpayne/go-geom" - "github.com/twpayne/go-geom/encoding/wkb" -) - -var ( - ErrNoObjFound = fmt.Errorf("no object found") - NoUniqueConstr = "unique constraint not defined for any field on type %s" + "github.com/hypermodeinc/modusdb/api/query_gen" + "github.com/hypermodeinc/modusdb/api/utils" ) type UniqueField interface { @@ -113,137 +103,108 @@ func getDefaultNamespace(db *DB, ns ...uint64) (context.Context, *Namespace, err return ctx, n, nil } -func valueToPosting_ValType(v any) (pb.Posting_ValType, error) { - switch v.(type) { - case string: - return pb.Posting_STRING, nil - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32: - return pb.Posting_INT, nil - case uint64: - return pb.Posting_UID, nil - case bool: - return pb.Posting_BOOL, nil - case float32, float64: - return pb.Posting_FLOAT, nil - case []byte: - return pb.Posting_BINARY, nil - case time.Time: - return pb.Posting_DATETIME, nil - case geom.Point: - return pb.Posting_GEO, nil - case []float32, []float64: - return pb.Posting_VFLOAT, nil - default: - return pb.Posting_DEFAULT, fmt.Errorf("unsupported type %T", v) - } -} - -func valueToApiVal(v any) (*api.Value, error) { - switch val := v.(type) { - case string: - return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil - case int: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int8: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int16: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int32: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int64: - return &api.Value{Val: &api.Value_IntVal{IntVal: val}}, nil - case uint8: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case uint16: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case uint32: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case uint64: - return &api.Value{Val: &api.Value_UidVal{UidVal: val}}, nil - case bool: - return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}}, nil - case float32: - return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil - case float64: - return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil - case []float32: - return &api.Value{Val: &api.Value_Vfloat32Val{ - Vfloat32Val: types.FloatArrayAsBytes(val)}}, nil - case []float64: - float32Slice := make([]float32, len(val)) - for i, v := range val { - float32Slice[i] = float32(v) - } - return &api.Value{Val: &api.Value_Vfloat32Val{ - Vfloat32Val: types.FloatArrayAsBytes(float32Slice)}}, nil - case []byte: - return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil - case time.Time: - bytes, err := val.MarshalBinary() - if err != nil { - return nil, err - } - return &api.Value{Val: &api.Value_DateVal{DateVal: bytes}}, nil - case geom.Point: - bytes, err := wkb.Marshal(&val, binary.LittleEndian) - if err != nil { - return nil, err - } - return &api.Value{Val: &api.Value_GeoVal{GeoVal: bytes}}, nil - case uint: - return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}}, nil - default: - return nil, fmt.Errorf("unsupported type %T", v) - } -} - -func filterToQueryFunc(typeName string, f Filter) QueryFunc { +func filterToQueryFunc(typeName string, f Filter) query_gen.QueryFunc { // Handle logical operators first if f.And != nil { - return And(filterToQueryFunc(typeName, *f.And)) + return query_gen.And(filterToQueryFunc(typeName, *f.And)) } if f.Or != nil { - return Or(filterToQueryFunc(typeName, *f.Or)) + return query_gen.Or(filterToQueryFunc(typeName, *f.Or)) } if f.Not != nil { - return Not(filterToQueryFunc(typeName, *f.Not)) + return query_gen.Not(filterToQueryFunc(typeName, *f.Not)) } // Handle field predicates if f.String.Equals != "" { - return buildEqQuery(getPredicateName(typeName, f.Field), f.String.Equals) + return query_gen.BuildEqQuery(utils.GetPredicateName(typeName, f.Field), f.String.Equals) } if len(f.String.AllOfTerms) != 0 { - return buildAllOfTermsQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AllOfTerms, " ")) + return query_gen.BuildAllOfTermsQuery(utils.GetPredicateName(typeName, + f.Field), strings.Join(f.String.AllOfTerms, " ")) } if len(f.String.AnyOfTerms) != 0 { - return buildAnyOfTermsQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfTerms, " ")) + return query_gen.BuildAnyOfTermsQuery(utils.GetPredicateName(typeName, + f.Field), strings.Join(f.String.AnyOfTerms, " ")) } if len(f.String.AllOfText) != 0 { - return buildAllOfTextQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AllOfText, " ")) + return query_gen.BuildAllOfTextQuery(utils.GetPredicateName(typeName, + f.Field), strings.Join(f.String.AllOfText, " ")) } if len(f.String.AnyOfText) != 0 { - return buildAnyOfTextQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfText, " ")) + return query_gen.BuildAnyOfTextQuery(utils.GetPredicateName(typeName, + f.Field), strings.Join(f.String.AnyOfText, " ")) } if f.String.RegExp != "" { - return buildRegExpQuery(getPredicateName(typeName, f.Field), f.String.RegExp) + return query_gen.BuildRegExpQuery(utils.GetPredicateName(typeName, + f.Field), f.String.RegExp) } if f.String.LessThan != "" { - return buildLtQuery(getPredicateName(typeName, f.Field), f.String.LessThan) + return query_gen.BuildLtQuery(utils.GetPredicateName(typeName, + f.Field), f.String.LessThan) } if f.String.LessOrEqual != "" { - return buildLeQuery(getPredicateName(typeName, f.Field), f.String.LessOrEqual) + return query_gen.BuildLeQuery(utils.GetPredicateName(typeName, + f.Field), f.String.LessOrEqual) } if f.String.GreaterThan != "" { - return buildGtQuery(getPredicateName(typeName, f.Field), f.String.GreaterThan) + return query_gen.BuildGtQuery(utils.GetPredicateName(typeName, + f.Field), f.String.GreaterThan) } if f.String.GreaterOrEqual != "" { - return buildGeQuery(getPredicateName(typeName, f.Field), f.String.GreaterOrEqual) + return query_gen.BuildGeQuery(utils.GetPredicateName(typeName, + f.Field), f.String.GreaterOrEqual) } if f.Vector.SimilarTo != nil { - return buildSimilarToQuery(getPredicateName(typeName, f.Field), f.Vector.TopK, f.Vector.SimilarTo) + return query_gen.BuildSimilarToQuery(utils.GetPredicateName(typeName, + f.Field), f.Vector.TopK, f.Vector.SimilarTo) } // Return empty query if no conditions match return func() string { return "" } } + +// Helper function to combine multiple filters +func filtersToQueryFunc(typeName string, filter Filter) query_gen.QueryFunc { + return filterToQueryFunc(typeName, filter) +} + +func paginationToQueryString(p Pagination) string { + paginationStr := "" + if p.Limit > 0 { + paginationStr += ", " + fmt.Sprintf("first: %d", p.Limit) + } + if p.Offset > 0 { + paginationStr += ", " + fmt.Sprintf("offset: %d", p.Offset) + } else if p.After != "" { + paginationStr += ", " + fmt.Sprintf("after: %s", p.After) + } + if paginationStr == "" { + return "" + } + return paginationStr +} + +func sortingToQueryString(typeName string, s Sorting) string { + if s.OrderAscField == "" && s.OrderDescField == "" { + return "" + } + + var parts []string + first, second := s.OrderDescField, s.OrderAscField + firstOp, secondOp := "orderdesc", "orderasc" + + if !s.OrderDescFirst { + first, second = s.OrderAscField, s.OrderDescField + firstOp, secondOp = "orderasc", "orderdesc" + } + + if first != "" { + parts = append(parts, fmt.Sprintf("%s: %s", firstOp, utils.GetPredicateName(typeName, first))) + } + if second != "" { + parts = append(parts, fmt.Sprintf("%s: %s", secondOp, utils.GetPredicateName(typeName, second))) + } + + return ", " + strings.Join(parts, ", ") +}