diff --git a/api.go b/api.go index 88a42b2..7fcdccd 100644 --- a/api.go +++ b/api.go @@ -1,91 +1,127 @@ package modusdb import ( - "context" "fmt" - "reflect" - "github.com/dgraph-io/dgraph/v24/x" + "github.com/dgraph-io/dgraph/v24/dql" + "github.com/dgraph-io/dgraph/v24/schema" ) -type ModusDbOption func(*modusDbOptions) - -type modusDbOptions struct { - namespace uint64 -} - -func WithNamespace(namespace uint64) ModusDbOption { - return func(o *modusDbOptions) { - o.namespace = namespace +func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { + db.mutex.Lock() + defer db.mutex.Unlock() + if len(ns) > 1 { + return 0, object, fmt.Errorf("only one namespace is allowed") + } + ctx, n, err := getDefaultNamespace(db, ns...) + if err != nil { + return 0, object, err } -} -func getDefaultNamespace(db *DB, ns ...uint64) (context.Context, *Namespace, error) { - dbOpts := &modusDbOptions{ - namespace: db.defaultNamespace.ID(), + gid, err := db.z.nextUID() + if err != nil { + return 0, object, err } - for _, ns := range ns { - WithNamespace(ns)(dbOpts) + + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateCreateDqlMutationsAndSchema[T](ctx, n, *object, gid, &dms, sch) + if err != nil { + return 0, object, err } - n, err := db.getNamespaceWithLock(dbOpts.namespace) + err = n.alterSchemaWithParsed(ctx, sch) if err != nil { - return nil, nil, err + return 0, object, err } - ctx := context.Background() - ctx = x.AttachNamespace(ctx, n.ID()) + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, object, err + } - return ctx, n, nil + return getByGid[T](ctx, n, gid) } -func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { +func Upsert[T any](db *DB, object *T, ns ...uint64) (uint64, *T, bool, error) { + + var wasFound bool db.mutex.Lock() defer db.mutex.Unlock() if len(ns) > 1 { - return 0, object, fmt.Errorf("only one namespace is allowed") + return 0, object, false, fmt.Errorf("only one namespace is allowed") } + if object == nil { + return 0, nil, false, fmt.Errorf("object is nil") + } + ctx, n, err := getDefaultNamespace(db, ns...) if err != nil { - return 0, object, err + return 0, object, false, err } - gid, err := db.z.nextUID() + gid, cf, err := getUniqueConstraint[T](*object) if err != nil { - return 0, object, err + return 0, nil, false, err } - dms, sch, err := generateCreateDqlMutationsAndSchema(n, object, gid) + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateCreateDqlMutationsAndSchema[T](ctx, n, *object, gid, &dms, sch) if err != nil { - return 0, object, err + return 0, nil, false, err } - ctx = x.AttachNamespace(ctx, n.ID()) - err = n.alterSchemaWithParsed(ctx, sch) if err != nil { - return 0, object, err + return 0, nil, false, err } - err = applyDqlMutations(ctx, db, dms) - if err != nil { - return 0, object, err + if gid != 0 { + gid, _, err = getByGidWithObject[T](ctx, n, gid, *object) + if err != nil && err != ErrNoObjFound { + return 0, nil, false, err + } + wasFound = err == nil + } else if cf != nil { + gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, *object) + if err != nil && err != ErrNoObjFound { + return 0, nil, false, err + } + wasFound = err == nil + } + if gid == 0 { + gid, err = db.z.nextUID() + if err != nil { + return 0, nil, false, err + } } - v := reflect.ValueOf(object).Elem() + dms = make([]*dql.Mutation, 0) + err = generateCreateDqlMutationsAndSchema[T](ctx, n, *object, gid, &dms, sch) + if err != nil { + return 0, nil, false, err + } - gidField := v.FieldByName("Gid") + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, nil, false, err + } - if gidField.IsValid() && gidField.CanSet() && gidField.Kind() == reflect.Uint64 { - gidField.SetUint(gid) + gid, object, err = getByGid[T](ctx, n, gid) + if err != nil { + return 0, nil, false, err } - return gid, object, nil + return gid, object, wasFound, nil } func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, error) { db.mutex.Lock() defer db.mutex.Unlock() + if len(ns) > 1 { + return 0, nil, fmt.Errorf("only one namespace is allowed") + } ctx, n, err := getDefaultNamespace(db, ns...) if err != nil { return 0, nil, err @@ -104,6 +140,9 @@ func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, error) { db.mutex.Lock() defer db.mutex.Unlock() + if len(ns) > 1 { + return 0, nil, fmt.Errorf("only one namespace is allowed") + } ctx, n, err := getDefaultNamespace(db, ns...) if err != nil { return 0, nil, err @@ -125,7 +164,19 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, } if cf, ok := any(uniqueField).(ConstrainedField); ok { - return getByConstrainedField[T](ctx, n, cf) + uid, obj, err := getByConstrainedField[T](ctx, n, cf) + if err != nil { + return 0, nil, err + } + + dms := generateDeleteDqlMutations(n, uid) + + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, nil, err + } + + return uid, obj, nil } return 0, nil, fmt.Errorf("invalid unique field type") diff --git a/api_dql.go b/api_dql.go new file mode 100644 index 0000000..a2b0e65 --- /dev/null +++ b/api_dql.go @@ -0,0 +1,41 @@ +package modusdb + +import "fmt" + +type QueryFunc func() string + +const ( + objQuery = ` + { + obj(%s) { + uid + expand(_all_) { + uid + expand(_all_) + dgraph.type + } + dgraph.type + %s + } + } + ` + + funcUid = `func: uid(%d)` + funcEq = `func: eq(%s, %s)` +) + +func buildUidQuery(gid uint64) QueryFunc { + return func() string { + return fmt.Sprintf(funcUid, gid) + } +} + +func buildEqQuery(key, value any) QueryFunc { + return func() string { + return fmt.Sprintf(funcEq, key, value) + } +} + +func formatObjQuery(qf QueryFunc, extraFields string) string { + return fmt.Sprintf(objQuery, qf(), extraFields) +} diff --git a/api_helper.go b/api_helper.go deleted file mode 100644 index 0cda75b..0000000 --- a/api_helper.go +++ /dev/null @@ -1,301 +0,0 @@ -package modusdb - -import ( - "context" - "encoding/binary" - "encoding/json" - "fmt" - "reflect" - "time" - - "github.com/dgraph-io/dgo/v240/protos/api" - "github.com/dgraph-io/dgraph/v24/dql" - "github.com/dgraph-io/dgraph/v24/protos/pb" - "github.com/dgraph-io/dgraph/v24/query" - "github.com/dgraph-io/dgraph/v24/schema" - "github.com/dgraph-io/dgraph/v24/worker" - "github.com/dgraph-io/dgraph/v24/x" - "github.com/twpayne/go-geom" - "github.com/twpayne/go-geom/encoding/wkb" -) - -func getPredicateName(typeName, fieldName string) string { - return fmt.Sprint(typeName, ".", fieldName) -} - -func addNamespace(ns uint64, pred string) string { - return x.NamespaceAttr(ns, pred) -} - -func valueToPosting_ValType(v any) (pb.Posting_ValType, error) { - switch v.(type) { - case string: - return pb.Posting_STRING, nil - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return pb.Posting_INT, nil - case bool: - return pb.Posting_BOOL, nil - case float32, float64: - return pb.Posting_FLOAT, nil - case []byte: - return pb.Posting_BINARY, nil - case time.Time: - return pb.Posting_DATETIME, nil - case geom.Point: - return pb.Posting_GEO, nil - case []float32, []float64: - return pb.Posting_VFLOAT, nil - default: - return pb.Posting_DEFAULT, fmt.Errorf("unsupported type %T", v) - } -} - -func valueToValType(v any) (*api.Value, error) { - switch val := v.(type) { - case string: - return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil - case int: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int8: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int16: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int32: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int64: - return &api.Value{Val: &api.Value_IntVal{IntVal: val}}, nil - case uint8: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case uint16: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case uint32: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case bool: - return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}}, nil - case float32: - return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil - case float64: - return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil - case []byte: - return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil - case time.Time: - bytes, err := val.MarshalBinary() - if err != nil { - return nil, err - } - return &api.Value{Val: &api.Value_DateVal{DateVal: bytes}}, nil - case geom.Point: - bytes, err := wkb.Marshal(&val, binary.LittleEndian) - if err != nil { - return nil, err - } - return &api.Value{Val: &api.Value_GeoVal{GeoVal: bytes}}, nil - case uint, uint64: - return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}}, nil - default: - return nil, fmt.Errorf("unsupported type %T", v) - } -} - -func generateCreateDqlMutationsAndSchema[T any](n *Namespace, object *T, - gid uint64) ([]*dql.Mutation, *schema.ParsedSchema, error) { - t := reflect.TypeOf(*object) - if t.Kind() != reflect.Struct { - return nil, nil, fmt.Errorf("expected struct, got %s", t.Kind()) - } - - jsonFields, dbFields, _, err := getFieldTags(t) - if err != nil { - return nil, nil, err - } - values := getFieldValues(object, jsonFields) - sch := &schema.ParsedSchema{} - - nquads := make([]*api.NQuad, 0) - for jsonName, value := range values { - if jsonName == "gid" { - continue - } - valType, err := valueToPosting_ValType(value) - if err != nil { - return nil, nil, err - } - u := &pb.SchemaUpdate{ - Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), - ValueType: valType, - } - if dbFields[jsonName] != nil && dbFields[jsonName].constraint == "unique" { - u.Directive = pb.SchemaUpdate_INDEX - u.Tokenizer = []string{"exact"} - } - sch.Preds = append(sch.Preds, u) - val, err := valueToValType(value) - if err != nil { - return nil, nil, err - } - nquad := &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: getPredicateName(t.Name(), jsonName), - ObjectValue: val, - } - nquads = append(nquads, nquad) - } - sch.Types = append(sch.Types, &pb.TypeUpdate{ - TypeName: addNamespace(n.id, t.Name()), - Fields: sch.Preds, - }) - - val, err := valueToValType(t.Name()) - if err != nil { - return nil, nil, err - } - nquad := &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: "dgraph.type", - ObjectValue: val, - } - nquads = append(nquads, nquad) - - dms := make([]*dql.Mutation, 0) - dms = append(dms, &dql.Mutation{ - Set: nquads, - }) - - return dms, sch, nil -} - -func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { - return []*dql.Mutation{{ - Del: []*api.NQuad{ - { - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: x.Star, - ObjectValue: &api.Value{ - Val: &api.Value_DefaultVal{DefaultVal: x.Star}, - }, - }, - }, - }} -} - -func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { - query := fmt.Sprintf(` - { - obj(func: uid(%d)) { - uid - expand(_all_) - dgraph.type - } - } - `, gid) - - return executeGet[T](ctx, n, query, nil) -} - -func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { - var obj T - - t := reflect.TypeOf(obj) - query := fmt.Sprintf(` - { - obj(func: eq(%s, %s)) { - uid - expand(_all_) - dgraph.type - } - } - `, getPredicateName(t.Name(), cf.Key), cf.Value) - - return executeGet[T](ctx, n, query, &cf) -} - -func executeGet[T any](ctx context.Context, n *Namespace, query string, cf *ConstrainedField) (uint64, *T, error) { - var obj T - - t := reflect.TypeOf(obj) - - jsonFields, dbTags, _, err := getFieldTags(t) - if err != nil { - return 0, nil, err - } - - if cf != nil && dbTags[cf.Key].constraint == "" { - return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) - } - - resp, err := n.queryWithLock(ctx, query) - if err != nil { - return 0, nil, err - } - - dynamicType := createDynamicStruct(t, jsonFields) - - dynamicInstance := reflect.New(dynamicType).Interface() - - var result struct { - Obj []any `json:"obj"` - } - - result.Obj = append(result.Obj, dynamicInstance) - - // Unmarshal the JSON response into the dynamic struct - if err := json.Unmarshal(resp.Json, &result); err != nil { - return 0, nil, err - } - - // Check if we have at least one object in the response - if len(result.Obj) == 0 { - return 0, nil, ErrNoObjFound - } - - // Map the dynamic struct to the final type T - finalObject := reflect.New(t).Interface() - gid, err := mapDynamicToFinal(result.Obj[0], finalObject) - if err != nil { - return 0, nil, err - } - - return gid, finalObject.(*T), nil -} - -func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { - edges, err := query.ToDirectedEdges(dms, nil) - if err != nil { - return err - } - - if !db.isOpen { - return ErrClosedDB - } - - startTs, err := db.z.nextTs() - if err != nil { - return err - } - commitTs, err := db.z.nextTs() - if err != nil { - return err - } - - m := &pb.Mutations{ - GroupId: 1, - StartTs: startTs, - Edges: edges, - } - m.Edges, err = query.ExpandEdges(ctx, m) - if err != nil { - return fmt.Errorf("error expanding edges: %w", err) - } - - p := &pb.Proposal{Mutations: m, StartTs: startTs} - if err := worker.ApplyMutations(ctx, p); err != nil { - return err - } - - return worker.ApplyCommited(ctx, &pb.OracleDelta{ - Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, - }) -} diff --git a/api_mutate_helper.go b/api_mutate_helper.go new file mode 100644 index 0000000..7957e0a --- /dev/null +++ b/api_mutate_helper.go @@ -0,0 +1,240 @@ +package modusdb + +import ( + "context" + "fmt" + "reflect" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/dql" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/query" + "github.com/dgraph-io/dgraph/v24/schema" + "github.com/dgraph-io/dgraph/v24/worker" + "github.com/dgraph-io/dgraph/v24/x" +) + +func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object T, + gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { + t := reflect.TypeOf(object) + if t.Kind() != reflect.Struct { + return fmt.Errorf("expected struct, got %s", t.Kind()) + } + + fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, err := getFieldTags(t) + if err != nil { + return err + } + jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) + + nquads := make([]*api.NQuad, 0) + uniqueConstraintFound := false + for jsonName, value := range jsonTagToValue { + if jsonToReverseEdgeTags[jsonName] != "" { + continue + } + if jsonName == "gid" { + uniqueConstraintFound = true + continue + } + var val *api.Value + var valType pb.Posting_ValType + + reflectValueType := reflect.TypeOf(value) + var nquad *api.NQuad + + if reflectValueType.Kind() == reflect.Struct { + value = reflect.ValueOf(value).Interface() + newGid, err := getUidOrMutate(ctx, n.db, n, value) + if err != nil { + return err + } + value = newGid + } else if reflectValueType.Kind() == reflect.Pointer { + // dereference the pointer + reflectValueType = reflectValueType.Elem() + if reflectValueType.Kind() == reflect.Struct { + // convert value to pointer, and then dereference + value = reflect.ValueOf(value).Elem().Interface() + newGid, err := getUidOrMutate(ctx, n.db, n, value) + if err != nil { + return err + } + value = newGid + } + } + valType, err = valueToPosting_ValType(value) + if err != nil { + return err + } + val, err = valueToApiVal(value) + if err != nil { + return err + } + + nquad = &api.NQuad{ + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: getPredicateName(t.Name(), jsonName), + } + + if valType == pb.Posting_UID { + nquad.ObjectId = fmt.Sprint(value) + } else { + nquad.ObjectValue = val + } + + u := &pb.SchemaUpdate{ + Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), + ValueType: valType, + } + if jsonToDbTags[jsonName] != nil { + constraint := jsonToDbTags[jsonName].constraint + if constraint == "unique" || constraint == "term" { + uniqueConstraintFound = true + u.Directive = pb.SchemaUpdate_INDEX + if constraint == "unique" { + u.Tokenizer = []string{"exact"} + } else { + u.Tokenizer = []string{"term"} + } + } + } + + sch.Preds = append(sch.Preds, u) + nquads = append(nquads, nquad) + } + if !uniqueConstraintFound { + return fmt.Errorf(NoUniqueConstr, t.Name()) + } + sch.Types = append(sch.Types, &pb.TypeUpdate{ + TypeName: addNamespace(n.id, t.Name()), + Fields: sch.Preds, + }) + + val, err := valueToApiVal(t.Name()) + if err != nil { + return err + } + typeNquad := &api.NQuad{ + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: "dgraph.type", + ObjectValue: val, + } + nquads = append(nquads, typeNquad) + + *dms = append(*dms, &dql.Mutation{ + Set: nquads, + }) + + return nil +} + +func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { + return []*dql.Mutation{{ + Del: []*api.NQuad{ + { + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: x.Star, + ObjectValue: &api.Value{ + Val: &api.Value_DefaultVal{DefaultVal: x.Star}, + }, + }, + }, + }} +} + +func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { + edges, err := query.ToDirectedEdges(dms, nil) + if err != nil { + return err + } + + if !db.isOpen { + return ErrClosedDB + } + + startTs, err := db.z.nextTs() + if err != nil { + return err + } + commitTs, err := db.z.nextTs() + if err != nil { + return err + } + + m := &pb.Mutations{ + GroupId: 1, + StartTs: startTs, + Edges: edges, + } + m.Edges, err = query.ExpandEdges(ctx, m) + if err != nil { + return fmt.Errorf("error expanding edges: %w", err) + } + + p := &pb.Proposal{Mutations: m, StartTs: startTs} + if err := worker.ApplyMutations(ctx, p); err != nil { + return err + } + + return worker.ApplyCommited(ctx, &pb.OracleDelta{ + Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, + }) +} + +func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { + gid, cf, err := getUniqueConstraint[T](object) + if err != nil { + return 0, err + } + + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, err + } + + err = n.alterSchemaWithParsed(ctx, sch) + if err != nil { + return 0, err + } + if gid != 0 { + gid, _, err = getByGidWithObject[T](ctx, n, gid, object) + if err != nil && err != ErrNoObjFound { + return 0, err + } + if err == nil { + return gid, nil + } + } else if cf != nil { + gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) + if err != nil && err != ErrNoObjFound { + return 0, err + } + if err == nil { + return gid, nil + } + } + + gid, err = db.z.nextUID() + if err != nil { + return 0, err + } + + dms = make([]*dql.Mutation, 0) + err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, err + } + + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, err + } + + return gid, nil +} diff --git a/api_query_helper.go b/api_query_helper.go new file mode 100644 index 0000000..e62d8f3 --- /dev/null +++ b/api_query_helper.go @@ -0,0 +1,123 @@ +package modusdb + +import ( + "context" + "encoding/json" + "fmt" + "reflect" +) + +func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { + return executeGet[T](ctx, n, gid) +} + +func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, obj T) (uint64, *T, error) { + return executeGetWithObject[T](ctx, n, obj, false, gid) +} + +func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { + return executeGet[T](ctx, n, cf) +} + +func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, + cf ConstrainedField, obj T) (uint64, *T, error) { + + return executeGetWithObject[T](ctx, n, obj, false, cf) +} + +func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, args ...R) (uint64, *T, error) { + if len(args) != 1 { + return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args)) + } + + var obj T + + return executeGetWithObject(ctx, n, obj, true, args...) +} + +func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, + obj T, withReverse bool, args ...R) (uint64, *T, error) { + t := reflect.TypeOf(obj) + + fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := getFieldTags(t) + if err != nil { + return 0, nil, err + } + readFromQuery := "" + if withReverse { + for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { + readFromQuery += fmt.Sprintf(` + %s: ~%s { + uid + expand(_all_) + dgraph.type + } + `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + } + } + + var cf ConstrainedField + var query string + gid, ok := any(args[0]).(uint64) + if ok { + query = formatObjQuery(buildUidQuery(gid), readFromQuery) + } else if cf, ok = any(args[0]).(ConstrainedField); ok { + query = formatObjQuery(buildEqQuery(getPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) + } else { + return 0, nil, fmt.Errorf("invalid unique field type") + } + + if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { + return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) + } + + resp, err := n.queryWithLock(ctx, query) + if err != nil { + return 0, nil, err + } + + dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) + + dynamicInstance := reflect.New(dynamicType).Interface() + + var result struct { + Obj []any `json:"obj"` + } + + result.Obj = append(result.Obj, dynamicInstance) + + // Unmarshal the JSON response into the dynamic struct + if err := json.Unmarshal(resp.Json, &result); err != nil { + return 0, nil, err + } + + // Check if we have at least one object in the response + if len(result.Obj) == 0 { + return 0, nil, ErrNoObjFound + } + + // Map the dynamic struct to the final type T + finalObject := reflect.New(t).Interface() + gid, err = mapDynamicToFinal(result.Obj[0], finalObject) + if err != nil { + return 0, nil, err + } + + // Convert to *interface{} then to *T + if ifacePtr, ok := finalObject.(*interface{}); ok { + if typedPtr, ok := (*ifacePtr).(*T); ok { + return gid, typedPtr, nil + } + } + + // If conversion fails, try direct conversion + if typedPtr, ok := finalObject.(*T); ok { + return gid, typedPtr, nil + } + + if dirType, ok := finalObject.(T); ok { + return gid, &dirType, nil + } + + return 0, nil, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) +} diff --git a/api_reflect.go b/api_reflect.go index 4c84813..f74cd53 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -11,12 +11,12 @@ type dbTag struct { constraint string } -func getFieldTags(t reflect.Type) (jsonTags map[string]string, jsonToDbTags map[string]*dbTag, - reverseEdgeTags map[string]string, err error) { +func getFieldTags(t reflect.Type) (fieldToJsonTags map[string]string, + jsonToDbTags map[string]*dbTag, jsonToReverseEdgeTags map[string]string, err error) { - jsonTags = make(map[string]string) + fieldToJsonTags = make(map[string]string) jsonToDbTags = make(map[string]*dbTag) - reverseEdgeTags = make(map[string]string) + jsonToReverseEdgeTags = make(map[string]string) for i := 0; i < t.NumField(); i++ { field := t.Field(i) jsonTag := field.Tag.Get("json") @@ -24,7 +24,7 @@ func getFieldTags(t reflect.Type) (jsonTags map[string]string, jsonToDbTags map[ return nil, nil, nil, fmt.Errorf("field %s has no json tag", field.Name) } jsonName := strings.Split(jsonTag, ",")[0] - jsonTags[field.Name] = jsonName + fieldToJsonTags[field.Name] = jsonName reverseEdgeTag := field.Tag.Get("readFrom") if reverseEdgeTag != "" { @@ -35,7 +35,7 @@ func getFieldTags(t reflect.Type) (jsonTags map[string]string, jsonToDbTags map[ } t := strings.Split(typeAndField[0], "=")[1] f := strings.Split(typeAndField[1], "=")[1] - reverseEdgeTags[field.Name] = getPredicateName(t, f) + jsonToReverseEdgeTags[jsonName] = getPredicateName(t, f) } dbConstraintsTag := field.Tag.Get("db") @@ -50,13 +50,16 @@ func getFieldTags(t reflect.Type) (jsonTags map[string]string, jsonToDbTags map[ } } } - return jsonTags, jsonToDbTags, reverseEdgeTags, nil + return fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, nil } -func getFieldValues(object any, jsonFields map[string]string) map[string]any { +func getJsonTagToValues(object any, fieldToJsonTags map[string]string) map[string]any { values := make(map[string]any) - v := reflect.ValueOf(object).Elem() - for fieldName, jsonName := range jsonFields { + v := reflect.ValueOf(object) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + for fieldName, jsonName := range fieldToJsonTags { fieldValue := v.FieldByName(fieldName) values[jsonName] = fieldValue.Interface() @@ -64,16 +67,38 @@ func getFieldValues(object any, jsonFields map[string]string) map[string]any { return values } -func createDynamicStruct(t reflect.Type, jsonFields map[string]string) reflect.Type { - fields := make([]reflect.StructField, 0, len(jsonFields)) - for fieldName, jsonName := range jsonFields { +func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, depth int) reflect.Type { + fields := make([]reflect.StructField, 0, len(fieldToJsonTags)) + for fieldName, jsonName := range fieldToJsonTags { field, _ := t.FieldByName(fieldName) if fieldName != "Gid" { - fields = append(fields, reflect.StructField{ - Name: field.Name, - Type: field.Type, - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), - }) + if field.Type.Kind() == reflect.Struct { + if depth <= 2 { + nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type) + nestedType := createDynamicStruct(field.Type, nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: nestedType, + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } + } else if field.Type.Kind() == reflect.Ptr && + field.Type.Elem().Kind() == reflect.Struct { + nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type.Elem()) + nestedType := createDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: reflect.PointerTo(nestedType), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } else { + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: field.Type, + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } + } } fields = append(fields, reflect.StructField{ @@ -95,30 +120,80 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { gid := uint64(0) for i := 0; i < vDynamic.NumField(); i++ { - field := vDynamic.Type().Field(i) - value := vDynamic.Field(i) + + dynamicField := vDynamic.Type().Field(i) + dynamicFieldType := dynamicField.Type + dynamicValue := vDynamic.Field(i) var finalField reflect.Value - if field.Name == "Uid" { + if dynamicField.Name == "Uid" { finalField = vFinal.FieldByName("Gid") - gidStr := value.String() + gidStr := dynamicValue.String() gid, _ = strconv.ParseUint(gidStr, 0, 64) - } else if field.Name == "DgraphType" { - fieldArr := value.Interface().([]string) + } else if dynamicField.Name == "DgraphType" { + fieldArr := dynamicValue.Interface().([]string) if len(fieldArr) == 0 { return 0, ErrNoObjFound } } else { - finalField = vFinal.FieldByName(field.Name) + finalField = vFinal.FieldByName(dynamicField.Name) } - if finalField.IsValid() && finalField.CanSet() { - // if field name is uid, convert it to uint64 - if field.Name == "Uid" { - finalField.SetUint(gid) - } else { - finalField.Set(value) + if dynamicFieldType.Kind() == reflect.Struct { + _, err := mapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface()) + if err != nil { + return 0, err + } + } else if dynamicFieldType.Kind() == reflect.Ptr && + dynamicFieldType.Elem().Kind() == reflect.Struct { + // if field is a pointer, find if the underlying is a struct + _, err := mapDynamicToFinal(dynamicValue.Interface(), finalField.Interface()) + if err != nil { + return 0, err + } + + } else { + if finalField.IsValid() && finalField.CanSet() { + // if field name is uid, convert it to uint64 + if dynamicField.Name == "Uid" { + finalField.SetUint(gid) + } else { + finalField.Set(dynamicValue) + } } } } return gid, nil } + +func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { + t := reflect.TypeOf(object) + fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) + if err != nil { + return 0, nil, err + } + jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) + + for jsonName, value := range jsonTagToValue { + if jsonName == "gid" { + gid, ok := value.(uint64) + if !ok { + continue + } + if gid != 0 { + return gid, nil, nil + } + } + if jsonToDbTags[jsonName] != nil && jsonToDbTags[jsonName].constraint == "unique" { + // check if value is zero or nil + if value == reflect.Zero(reflect.TypeOf(value)).Interface() || value == nil { + continue + } + return 0, &ConstrainedField{ + Key: jsonName, + Value: value, + }, nil + } + } + + return 0, nil, fmt.Errorf(NoUniqueConstr, t.Name()) +} diff --git a/api_test.go b/api_test.go index 5b1c6bc..1d3293c 100644 --- a/api_test.go +++ b/api_test.go @@ -2,6 +2,7 @@ package modusdb_test import ( "context" + "fmt" "testing" "github.com/stretchr/testify/require" @@ -16,6 +17,52 @@ type User struct { ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` } +func TestFirstTimeUser(t *testing.T) { + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + gid, user, err := modusdb.Create(db, &User{ + Name: "A", + Age: 10, + ClerkId: "123", + }) + + require.NoError(t, err) + require.Equal(t, user.Gid, gid) + require.Equal(t, "A", user.Name) + require.Equal(t, 10, user.Age) + require.Equal(t, "123", user.ClerkId) + + gid, queriedUser, err := modusdb.Get[User](db, gid) + + require.NoError(t, err) + require.Equal(t, queriedUser.Gid, gid) + require.Equal(t, 10, queriedUser.Age) + require.Equal(t, "A", queriedUser.Name) + require.Equal(t, "123", queriedUser.ClerkId) + + gid, queriedUser2, err := modusdb.Get[User](db, modusdb.ConstrainedField{ + Key: "clerk_id", + Value: "123", + }) + + require.NoError(t, err) + require.Equal(t, queriedUser.Gid, gid) + require.Equal(t, 10, queriedUser2.Age) + require.Equal(t, "A", queriedUser2.Name) + require.Equal(t, "123", queriedUser2.ClerkId) + + _, _, err = modusdb.Delete[User](db, gid) + require.NoError(t, err) + + _, queriedUser3, err := modusdb.Get[User](db, gid) + require.Error(t, err) + require.Equal(t, "no object found", err.Error()) + require.Nil(t, queriedUser3) + +} + func TestCreateApi(t *testing.T) { ctx := context.Background() db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) @@ -33,12 +80,11 @@ func TestCreateApi(t *testing.T) { ClerkId: "123", } - gid, _, err := modusdb.Create(db, user, db1.ID()) + gid, user, err := modusdb.Create(db, user, db1.ID()) require.NoError(t, err) require.Equal(t, "B", user.Name) - require.Equal(t, uint64(2), gid) - require.Equal(t, uint64(2), user.Gid) + require.Equal(t, user.Gid, gid) query := `{ me(func: has(User.name)) { @@ -118,8 +164,7 @@ func TestGetApi(t *testing.T) { gid, queriedUser, err := modusdb.Get[User](db, gid, db1.ID()) require.NoError(t, err) - require.Equal(t, uint64(2), gid) - require.Equal(t, uint64(2), queriedUser.Gid) + require.Equal(t, queriedUser.Gid, gid) require.Equal(t, 20, queriedUser.Age) require.Equal(t, "B", queriedUser.Name) require.Equal(t, "123", queriedUser.ClerkId) @@ -151,8 +196,7 @@ func TestGetApiWithConstrainedField(t *testing.T) { }, db1.ID()) require.NoError(t, err) - require.Equal(t, uint64(2), gid) - require.Equal(t, uint64(2), queriedUser.Gid) + require.Equal(t, queriedUser.Gid, gid) require.Equal(t, 20, queriedUser.Age) require.Equal(t, "B", queriedUser.Name) require.Equal(t, "123", queriedUser.ClerkId) @@ -194,3 +238,277 @@ func TestDeleteApi(t *testing.T) { require.Equal(t, "no object found", err.Error()) require.Nil(t, queriedUser) } + +func TestUpsertApi(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + user := &User{ + Name: "B", + Age: 20, + ClerkId: "123", + } + + gid, user, _, err := modusdb.Upsert(db, user, db1.ID()) + require.NoError(t, err) + require.Equal(t, user.Gid, gid) + + user.Age = 21 + gid, _, _, err = modusdb.Upsert(db, user, db1.ID()) + require.NoError(t, err) + require.Equal(t, user.Gid, gid) + + _, queriedUser, err := modusdb.Get[User](db, gid, db1.ID()) + require.NoError(t, err) + require.Equal(t, user.Gid, queriedUser.Gid) + require.Equal(t, 21, queriedUser.Age) + require.Equal(t, "B", queriedUser.Name) + require.Equal(t, "123", queriedUser.ClerkId) +} + +type Project struct { + Gid uint64 `json:"gid,omitempty"` + Name string `json:"name,omitempty"` + ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` + // Branches []Branch `json:"branches,omitempty" readFrom:"type=Branch,field=proj"` +} + +type Branch struct { + Gid uint64 `json:"gid,omitempty"` + Name string `json:"name,omitempty"` + ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` + Proj Project `json:"proj,omitempty"` +} + +func TestNestedObjectMutation(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + branch := &Branch{ + Name: "B", + ClerkId: "123", + Proj: Project{ + Name: "P", + ClerkId: "456", + }, + } + + gid, branch, err := modusdb.Create(db, branch, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "B", branch.Name) + require.Equal(t, branch.Gid, gid) + require.NotEqual(t, uint64(0), branch.Proj.Gid) + require.Equal(t, "P", branch.Proj.Name) + + query := `{ + me(func: has(Branch.name)) { + uid + Branch.name + Branch.clerk_id + Branch.proj { + uid + Project.name + Project.clerk_id + } + } + }` + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, + `{"me":[{"uid":"0x2","Branch.name":"B","Branch.clerk_id":"123","Branch.proj": + {"uid":"0x3","Project.name":"P","Project.clerk_id":"456"}}]}`, + string(resp.GetJson())) + + gid, queriedBranch, err := modusdb.Get[Branch](db, gid, db1.ID()) + require.NoError(t, err) + require.Equal(t, queriedBranch.Gid, gid) + require.Equal(t, "B", queriedBranch.Name) + +} + +func TestLinkingObjectsByConstrainedFields(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + projGid, project, err := modusdb.Create(db, &Project{ + Name: "P", + ClerkId: "456", + }, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "P", project.Name) + require.Equal(t, project.Gid, projGid) + + branch := &Branch{ + Name: "B", + ClerkId: "123", + Proj: Project{ + Name: "P", + ClerkId: "456", + }, + } + + gid, branch, err := modusdb.Create(db, branch, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "B", branch.Name) + require.Equal(t, branch.Gid, gid) + require.Equal(t, projGid, branch.Proj.Gid) + require.Equal(t, "P", branch.Proj.Name) + + query := `{ + me(func: has(Branch.name)) { + uid + Branch.name + Branch.clerk_id + Branch.proj { + uid + Project.name + Project.clerk_id + } + } + }` + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, + `{"me":[{"uid":"0x3","Branch.name":"B","Branch.clerk_id":"123","Branch.proj": + {"uid":"0x2","Project.name":"P","Project.clerk_id":"456"}}]}`, + string(resp.GetJson())) + + gid, queriedBranch, err := modusdb.Get[Branch](db, gid, db1.ID()) + require.NoError(t, err) + require.Equal(t, queriedBranch.Gid, gid) + require.Equal(t, "B", queriedBranch.Name) + +} + +func TestLinkingObjectsByGid(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + projGid, project, err := modusdb.Create(db, &Project{ + Name: "P", + ClerkId: "456", + }, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "P", project.Name) + require.Equal(t, project.Gid, projGid) + + branch := &Branch{ + Name: "B", + ClerkId: "123", + Proj: Project{ + Gid: projGid, + }, + } + + gid, branch, err := modusdb.Create(db, branch, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "B", branch.Name) + require.Equal(t, branch.Gid, gid) + require.Equal(t, projGid, branch.Proj.Gid) + require.Equal(t, "P", branch.Proj.Name) + + query := `{ + me(func: has(Branch.name)) { + uid + Branch.name + Branch.clerk_id + Branch.proj { + uid + Project.name + Project.clerk_id + } + } + }` + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, + `{"me":[{"uid":"0x3","Branch.name":"B","Branch.clerk_id":"123", + "Branch.proj":{"uid":"0x2","Project.name":"P","Project.clerk_id":"456"}}]}`, + string(resp.GetJson())) + + gid, queriedBranch, err := modusdb.Get[Branch](db, gid, db1.ID()) + require.NoError(t, err) + require.Equal(t, queriedBranch.Gid, gid) + require.Equal(t, "B", queriedBranch.Name) + +} + +type BadProject struct { + Name string `json:"name,omitempty"` + ClerkId string `json:"clerk_id,omitempty"` +} + +type BadBranch struct { + Gid uint64 `json:"gid,omitempty"` + Name string `json:"name,omitempty"` + ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` + Proj BadProject `json:"proj,omitempty"` +} + +func TestNestedObjectMutationWithBadType(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + branch := &BadBranch{ + Name: "B", + ClerkId: "123", + Proj: BadProject{ + Name: "P", + ClerkId: "456", + }, + } + + _, _, err = modusdb.Create(db, branch, db1.ID()) + require.Error(t, err) + require.Equal(t, fmt.Sprintf(modusdb.NoUniqueConstr, "BadProject"), err.Error()) + + proj := &BadProject{ + Name: "P", + ClerkId: "456", + } + + _, _, err = modusdb.Create(db, proj, db1.ID()) + require.Error(t, err) + require.Equal(t, fmt.Sprintf(modusdb.NoUniqueConstr, "BadProject"), err.Error()) + +} diff --git a/api_types.go b/api_types.go index 9cea571..860edda 100644 --- a/api_types.go +++ b/api_types.go @@ -1,9 +1,21 @@ package modusdb -import "fmt" +import ( + "context" + "encoding/binary" + "fmt" + "time" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/x" + "github.com/twpayne/go-geom" + "github.com/twpayne/go-geom/encoding/wkb" +) var ( - ErrNoObjFound = fmt.Errorf("no object found") + ErrNoObjFound = fmt.Errorf("no object found") + NoUniqueConstr = "unique constraint not defined for any field on type %s" ) type UniqueField interface { @@ -13,3 +25,108 @@ type ConstrainedField struct { Key string Value any } + +type ModusDbOption func(*modusDbOptions) + +type modusDbOptions struct { + namespace uint64 +} + +func WithNamespace(namespace uint64) ModusDbOption { + return func(o *modusDbOptions) { + o.namespace = namespace + } +} + +func getDefaultNamespace(db *DB, ns ...uint64) (context.Context, *Namespace, error) { + dbOpts := &modusDbOptions{ + namespace: db.defaultNamespace.ID(), + } + for _, ns := range ns { + WithNamespace(ns)(dbOpts) + } + + n, err := db.getNamespaceWithLock(dbOpts.namespace) + if err != nil { + return nil, nil, err + } + + ctx := context.Background() + ctx = x.AttachNamespace(ctx, n.ID()) + + return ctx, n, nil +} + +func valueToPosting_ValType(v any) (pb.Posting_ValType, error) { + switch v.(type) { + case string: + return pb.Posting_STRING, nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32: + return pb.Posting_INT, nil + case uint64: + return pb.Posting_UID, nil + case bool: + return pb.Posting_BOOL, nil + case float32, float64: + return pb.Posting_FLOAT, nil + case []byte: + return pb.Posting_BINARY, nil + case time.Time: + return pb.Posting_DATETIME, nil + case geom.Point: + return pb.Posting_GEO, nil + case []float32, []float64: + return pb.Posting_VFLOAT, nil + default: + return pb.Posting_DEFAULT, fmt.Errorf("unsupported type %T", v) + } +} + +func valueToApiVal(v any) (*api.Value, error) { + switch val := v.(type) { + case string: + return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil + case int: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int8: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int16: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int32: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int64: + return &api.Value{Val: &api.Value_IntVal{IntVal: val}}, nil + case uint8: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint16: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint32: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint64: + return &api.Value{Val: &api.Value_UidVal{UidVal: val}}, nil + case bool: + return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}}, nil + case float32: + return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil + case float64: + return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil + case []byte: + return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil + case time.Time: + bytes, err := val.MarshalBinary() + if err != nil { + return nil, err + } + return &api.Value{Val: &api.Value_DateVal{DateVal: bytes}}, nil + case geom.Point: + bytes, err := wkb.Marshal(&val, binary.LittleEndian) + if err != nil { + return nil, err + } + return &api.Value{Val: &api.Value_GeoVal{GeoVal: bytes}}, nil + case uint: + return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}}, nil + default: + return nil, fmt.Errorf("unsupported type %T", v) + } +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..133ee1d --- /dev/null +++ b/utils.go @@ -0,0 +1,15 @@ +package modusdb + +import ( + "fmt" + + "github.com/dgraph-io/dgraph/v24/x" +) + +func getPredicateName(typeName, fieldName string) string { + return fmt.Sprint(typeName, ".", fieldName) +} + +func addNamespace(ns uint64, pred string) string { + return x.NamespaceAttr(ns, pred) +}