From f9db3838018d31b11e9dd32c7de16b71c8c13a9e Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Mon, 23 Dec 2024 00:47:06 -0800 Subject: [PATCH 01/12] add nested mutation, upsert, nested types, reverse edge support --- api.go | 40 ++++++++-- api_helper.go | 206 +++++++++++++++++++++++++++++++++++++++---------- api_reflect.go | 20 ++--- 3 files changed, 212 insertions(+), 54 deletions(-) diff --git a/api.go b/api.go index 88a42b2..4966e5b 100644 --- a/api.go +++ b/api.go @@ -5,6 +5,8 @@ import ( "fmt" "reflect" + "github.com/dgraph-io/dgraph/v24/dql" + "github.com/dgraph-io/dgraph/v24/schema" "github.com/dgraph-io/dgraph/v24/x" ) @@ -55,7 +57,9 @@ func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { return 0, object, err } - dms, sch, err := generateCreateDqlMutationsAndSchema(n, object, gid) + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) if err != nil { return 0, object, err } @@ -91,11 +95,11 @@ func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, return 0, nil, err } if uid, ok := any(uniqueField).(uint64); ok { - return getByGid[T](ctx, n, uid) + return getByGid[T](ctx, n, uid, true) } if cf, ok := any(uniqueField).(ConstrainedField); ok { - return getByConstrainedField[T](ctx, n, cf) + return getByConstrainedField[T](ctx, n, cf, true) } return 0, nil, fmt.Errorf("invalid unique field type") @@ -109,7 +113,7 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, return 0, nil, err } if uid, ok := any(uniqueField).(uint64); ok { - uid, obj, err := getByGid[T](ctx, n, uid) + uid, obj, err := getByGid[T](ctx, n, uid, true) if err != nil { return 0, nil, err } @@ -125,8 +129,34 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, } if cf, ok := any(uniqueField).(ConstrainedField); ok { - return getByConstrainedField[T](ctx, n, cf) + uid, obj, err := getByConstrainedField[T](ctx, n, cf, true) + if err != nil { + return 0, nil, err + } + + dms := generateDeleteDqlMutations(n, uid) + + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, nil, err + } + + return uid, obj, nil } return 0, nil, fmt.Errorf("invalid unique field type") } + +func Upsert[T any](db *DB, object *T, ns ...uint64) (uint64, *T, bool, error) { + db.mutex.Lock() + defer db.mutex.Unlock() + if len(ns) > 1 { + return 0, object, false, fmt.Errorf("only one namespace is allowed") + } + ctx, n, err := getDefaultNamespace(db, ns...) + if err != nil { + return 0, object, false, err + } + + return upsertHelper[T](ctx, db, n, object, true) +} diff --git a/api_helper.go b/api_helper.go index 0cda75b..6baad79 100644 --- a/api_helper.go +++ b/api_helper.go @@ -50,7 +50,7 @@ func valueToPosting_ValType(v any) (pb.Posting_ValType, error) { } } -func valueToValType(v any) (*api.Value, error) { +func valueToApiVal(v any) (*api.Value, error) { switch val := v.(type) { case string: return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil @@ -97,42 +97,66 @@ func valueToValType(v any) (*api.Value, error) { } } -func generateCreateDqlMutationsAndSchema[T any](n *Namespace, object *T, - gid uint64) ([]*dql.Mutation, *schema.ParsedSchema, error) { +func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object *T, + gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { t := reflect.TypeOf(*object) if t.Kind() != reflect.Struct { - return nil, nil, fmt.Errorf("expected struct, got %s", t.Kind()) + return fmt.Errorf("expected struct, got %s", t.Kind()) } - jsonFields, dbFields, _, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) if err != nil { - return nil, nil, err + return err } - values := getFieldValues(object, jsonFields) - sch := &schema.ParsedSchema{} + values := getJsonTagToValues(object, fieldToJsonTags) nquads := make([]*api.NQuad, 0) for jsonName, value := range values { if jsonName == "gid" { continue } - valType, err := valueToPosting_ValType(value) - if err != nil { - return nil, nil, err + var val *api.Value + var valType pb.Posting_ValType + if reflect.TypeOf(value).Kind() == reflect.Struct { + gid, _, _, err := upsertHelper(ctx, n.db, n, &value, false) + if err != nil { + return err + } + valType, err = valueToPosting_ValType(fmt.Sprint(gid)) + if err != nil { + return err + } + val, err = valueToApiVal(fmt.Sprint(gid)) + if err != nil { + return err + } + } else { + valType, err = valueToPosting_ValType(value) + if err != nil { + return err + } + val, err = valueToApiVal(value) + if err != nil { + return err + } } u := &pb.SchemaUpdate{ Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), ValueType: valType, } - if dbFields[jsonName] != nil && dbFields[jsonName].constraint == "unique" { - u.Directive = pb.SchemaUpdate_INDEX - u.Tokenizer = []string{"exact"} + if jsonToDbTags[jsonName] != nil { + constraint := jsonToDbTags[jsonName].constraint + if constraint == "unique" || constraint == "term" { + u.Directive = pb.SchemaUpdate_INDEX + if constraint == "unique" { + u.Tokenizer = []string{"exact"} + } else { + u.Tokenizer = []string{"term"} + } + } } + sch.Preds = append(sch.Preds, u) - val, err := valueToValType(value) - if err != nil { - return nil, nil, err - } nquad := &api.NQuad{ Namespace: n.ID(), Subject: fmt.Sprint(gid), @@ -146,9 +170,9 @@ func generateCreateDqlMutationsAndSchema[T any](n *Namespace, object *T, Fields: sch.Preds, }) - val, err := valueToValType(t.Name()) + val, err := valueToApiVal(t.Name()) if err != nil { - return nil, nil, err + return err } nquad := &api.NQuad{ Namespace: n.ID(), @@ -158,12 +182,11 @@ func generateCreateDqlMutationsAndSchema[T any](n *Namespace, object *T, } nquads = append(nquads, nquad) - dms := make([]*dql.Mutation, 0) - dms = append(dms, &dql.Mutation{ + *dms = append(*dms, &dql.Mutation{ Set: nquads, }) - return dms, sch, nil + return nil } func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { @@ -181,48 +204,71 @@ func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { }} } -func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { - query := fmt.Sprintf(` +func getByGid[T any](ctx context.Context, n *Namespace, gid uint64, readFrom bool) (uint64, *T, error) { + query := ` { obj(func: uid(%d)) { uid expand(_all_) dgraph.type + %s } } - `, gid) + ` - return executeGet[T](ctx, n, query, nil) + return executeGet[T](ctx, n, query, readFrom, gid) } -func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { - var obj T - - t := reflect.TypeOf(obj) - query := fmt.Sprintf(` +func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField, readFrom bool) (uint64, *T, error) { + query := ` { obj(func: eq(%s, %s)) { uid expand(_all_) dgraph.type + %s } } - `, getPredicateName(t.Name(), cf.Key), cf.Value) + ` - return executeGet[T](ctx, n, query, &cf) + return executeGet[T](ctx, n, query, readFrom, cf) } -func executeGet[T any](ctx context.Context, n *Namespace, query string, cf *ConstrainedField) (uint64, *T, error) { +func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, query string, readFrom bool, args ...R) (uint64, *T, error) { + if len(args) != 1 { + return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args)) + } + var obj T t := reflect.TypeOf(obj) - jsonFields, dbTags, _, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTag, reverseEdgeTags, err := getFieldTags(t) if err != nil { return 0, nil, err } + readFromQuery := "" + for fieldName, reverseEdgeTag := range reverseEdgeTags { + readFromQuery += fmt.Sprintf(` + %s: ~%s { + uid + expand(_all_) + dgraph.type + } + `, getPredicateName(t.Name(), fieldToJsonTags[fieldName]), reverseEdgeTag) + } - if cf != nil && dbTags[cf.Key].constraint == "" { + var cf ConstrainedField + gid, ok := any(args[0]).(uint64) + if ok { + query = fmt.Sprintf(query, gid, readFromQuery) + } else if cf, ok = any(args[0]).(ConstrainedField); ok { + query = fmt.Sprintf(query, getPredicateName(t.Name(), cf.Key), cf.Value, readFromQuery) + } else { + return 0, nil, fmt.Errorf("invalid unique field type") + } + + if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) } @@ -231,7 +277,7 @@ func executeGet[T any](ctx context.Context, n *Namespace, query string, cf *Cons return 0, nil, err } - dynamicType := createDynamicStruct(t, jsonFields) + dynamicType := createDynamicStruct(t, fieldToJsonTags) dynamicInstance := reflect.New(dynamicType).Interface() @@ -253,7 +299,7 @@ func executeGet[T any](ctx context.Context, n *Namespace, query string, cf *Cons // Map the dynamic struct to the final type T finalObject := reflect.New(t).Interface() - gid, err := mapDynamicToFinal(result.Obj[0], finalObject) + gid, err = mapDynamicToFinal(result.Obj[0], finalObject) if err != nil { return 0, nil, err } @@ -299,3 +345,85 @@ func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, }) } + +func getUniqueConstraint[T any](object *T) (uint64, *ConstrainedField, error) { + t := reflect.TypeOf(*object) + fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) + if err != nil { + return 0, nil, err + } + values := getJsonTagToValues(object, fieldToJsonTags) + + for jsonName, value := range values { + if jsonName == "gid" { + gid, ok := value.(uint64) + if !ok { + return 0, nil, fmt.Errorf("expected uint64 type for gid, got %T", value) + } + if gid != 0 { + return gid, nil, nil + } + } + if jsonToDbTags[jsonName] != nil && jsonToDbTags[jsonName].constraint == "unique" { + return 0, &ConstrainedField{ + Key: jsonName, + Value: value, + }, nil + } + } + + return 0, nil, fmt.Errorf("unique constraint not defined for any field on type %s", t.Name()) +} + +func upsertHelper[T any](ctx context.Context, db *DB, n *Namespace, object *T, readFrom bool) (uint64, *T, bool, error) { + gid, cf, err := getUniqueConstraint(object) + if err != nil { + return 0, object, false, err + } + if gid != 0 { + gid, object, err := getByGid[T](ctx, n, gid, readFrom) + if err != nil { + return 0, object, false, err + } + return gid, object, true, nil + } else if cf != nil { + gid, object, err := getByConstrainedField[T](ctx, n, *cf, readFrom) + if err == nil { + return gid, object, true, nil + } + } + + gid, err = db.z.nextUID() + if err != nil { + return 0, object, false, err + } + + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, object, false, err + } + + ctx = x.AttachNamespace(ctx, n.ID()) + + err = n.alterSchemaWithParsed(ctx, sch) + if err != nil { + return 0, object, false, err + } + + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, object, false, err + } + + v := reflect.ValueOf(object).Elem() + + gidField := v.FieldByName("Gid") + + if gidField.IsValid() && gidField.CanSet() && gidField.Kind() == reflect.Uint64 { + gidField.SetUint(gid) + } + + return gid, object, false, nil +} diff --git a/api_reflect.go b/api_reflect.go index 4c84813..9e1ac2e 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -11,10 +11,10 @@ type dbTag struct { constraint string } -func getFieldTags(t reflect.Type) (jsonTags map[string]string, jsonToDbTags map[string]*dbTag, - reverseEdgeTags map[string]string, err error) { +func getFieldTags(t reflect.Type) (fieldToJsonTags map[string]string, + jsonToDbTags map[string]*dbTag, reverseEdgeTags map[string]string, err error) { - jsonTags = make(map[string]string) + fieldToJsonTags = make(map[string]string) jsonToDbTags = make(map[string]*dbTag) reverseEdgeTags = make(map[string]string) for i := 0; i < t.NumField(); i++ { @@ -24,7 +24,7 @@ func getFieldTags(t reflect.Type) (jsonTags map[string]string, jsonToDbTags map[ return nil, nil, nil, fmt.Errorf("field %s has no json tag", field.Name) } jsonName := strings.Split(jsonTag, ",")[0] - jsonTags[field.Name] = jsonName + fieldToJsonTags[field.Name] = jsonName reverseEdgeTag := field.Tag.Get("readFrom") if reverseEdgeTag != "" { @@ -50,13 +50,13 @@ func getFieldTags(t reflect.Type) (jsonTags map[string]string, jsonToDbTags map[ } } } - return jsonTags, jsonToDbTags, reverseEdgeTags, nil + return fieldToJsonTags, jsonToDbTags, reverseEdgeTags, nil } -func getFieldValues(object any, jsonFields map[string]string) map[string]any { +func getJsonTagToValues(object any, fieldToJsonTags map[string]string) map[string]any { values := make(map[string]any) v := reflect.ValueOf(object).Elem() - for fieldName, jsonName := range jsonFields { + for fieldName, jsonName := range fieldToJsonTags { fieldValue := v.FieldByName(fieldName) values[jsonName] = fieldValue.Interface() @@ -64,9 +64,9 @@ func getFieldValues(object any, jsonFields map[string]string) map[string]any { return values } -func createDynamicStruct(t reflect.Type, jsonFields map[string]string) reflect.Type { - fields := make([]reflect.StructField, 0, len(jsonFields)) - for fieldName, jsonName := range jsonFields { +func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string) reflect.Type { + fields := make([]reflect.StructField, 0, len(fieldToJsonTags)) + for fieldName, jsonName := range fieldToJsonTags { field, _ := t.FieldByName(fieldName) if fieldName != "Gid" { fields = append(fields, reflect.StructField{ From 5db64236bd751af1750e4f8ff127fe07d0b23f5b Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Mon, 23 Dec 2024 23:54:31 -0800 Subject: [PATCH 02/12] add nested mutations --- api.go | 26 ++------- api_helper.go | 154 +++++++++++++++++++++++++++++++------------------ api_reflect.go | 13 +++-- api_test.go | 148 +++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 252 insertions(+), 89 deletions(-) diff --git a/api.go b/api.go index 4966e5b..5588a1b 100644 --- a/api.go +++ b/api.go @@ -59,13 +59,11 @@ func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + err = generateCreateDqlMutationsAndSchema(ctx, n, *object, gid, &dms, sch) if err != nil { return 0, object, err } - ctx = x.AttachNamespace(ctx, n.ID()) - err = n.alterSchemaWithParsed(ctx, sch) if err != nil { return 0, object, err @@ -95,11 +93,11 @@ func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, return 0, nil, err } if uid, ok := any(uniqueField).(uint64); ok { - return getByGid[T](ctx, n, uid, true) + return getByGid[T](ctx, n, uid) } if cf, ok := any(uniqueField).(ConstrainedField); ok { - return getByConstrainedField[T](ctx, n, cf, true) + return getByConstrainedField[T](ctx, n, cf) } return 0, nil, fmt.Errorf("invalid unique field type") @@ -113,7 +111,7 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, return 0, nil, err } if uid, ok := any(uniqueField).(uint64); ok { - uid, obj, err := getByGid[T](ctx, n, uid, true) + uid, obj, err := getByGid[T](ctx, n, uid) if err != nil { return 0, nil, err } @@ -129,7 +127,7 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, } if cf, ok := any(uniqueField).(ConstrainedField); ok { - uid, obj, err := getByConstrainedField[T](ctx, n, cf, true) + uid, obj, err := getByConstrainedField[T](ctx, n, cf) if err != nil { return 0, nil, err } @@ -146,17 +144,3 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, return 0, nil, fmt.Errorf("invalid unique field type") } - -func Upsert[T any](db *DB, object *T, ns ...uint64) (uint64, *T, bool, error) { - db.mutex.Lock() - defer db.mutex.Unlock() - if len(ns) > 1 { - return 0, object, false, fmt.Errorf("only one namespace is allowed") - } - ctx, n, err := getDefaultNamespace(db, ns...) - if err != nil { - return 0, object, false, err - } - - return upsertHelper[T](ctx, db, n, object, true) -} diff --git a/api_helper.go b/api_helper.go index 6baad79..57d84ef 100644 --- a/api_helper.go +++ b/api_helper.go @@ -97,38 +97,44 @@ func valueToApiVal(v any) (*api.Value, error) { } } -func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object *T, +func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object T, gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { - t := reflect.TypeOf(*object) + t := reflect.TypeOf(object) if t.Kind() != reflect.Struct { return fmt.Errorf("expected struct, got %s", t.Kind()) } - fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, err := getFieldTags(t) if err != nil { return err } - values := getJsonTagToValues(object, fieldToJsonTags) + jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) nquads := make([]*api.NQuad, 0) - for jsonName, value := range values { + for jsonName, value := range jsonTagToValue { + if jsonToReverseEdgeTags[jsonName] != "" { + continue + } if jsonName == "gid" { continue } var val *api.Value var valType pb.Posting_ValType - if reflect.TypeOf(value).Kind() == reflect.Struct { - gid, _, _, err := upsertHelper(ctx, n.db, n, &value, false) - if err != nil { - return err - } - valType, err = valueToPosting_ValType(fmt.Sprint(gid)) + + reflectValueType := reflect.TypeOf(value) + var nquad *api.NQuad + if reflectValueType.Kind() == reflect.Struct { + newGid, err := getUidOrMutate(ctx, n.db, n, value) if err != nil { return err } - val, err = valueToApiVal(fmt.Sprint(gid)) - if err != nil { - return err + valType = pb.Posting_UID + + nquad = &api.NQuad{ + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: getPredicateName(t.Name(), jsonName), + ObjectId: fmt.Sprint(newGid), } } else { valType, err = valueToPosting_ValType(value) @@ -139,6 +145,13 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac if err != nil { return err } + + nquad = &api.NQuad{ + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: getPredicateName(t.Name(), jsonName), + ObjectValue: val, + } } u := &pb.SchemaUpdate{ Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), @@ -157,12 +170,6 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac } sch.Preds = append(sch.Preds, u) - nquad := &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: getPredicateName(t.Name(), jsonName), - ObjectValue: val, - } nquads = append(nquads, nquad) } sch.Types = append(sch.Types, &pb.TypeUpdate{ @@ -204,7 +211,22 @@ func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { }} } -func getByGid[T any](ctx context.Context, n *Namespace, gid uint64, readFrom bool) (uint64, *T, error) { +func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { + query := ` + { + obj(func: uid(%d)) { + uid + expand(_all_) + dgraph.type + %s + } + } + ` + + return executeGet[T](ctx, n, query, gid) +} + +func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, obj T) (uint64, *T, error) { query := ` { obj(func: uid(%d)) { @@ -216,10 +238,25 @@ func getByGid[T any](ctx context.Context, n *Namespace, gid uint64, readFrom boo } ` - return executeGet[T](ctx, n, query, readFrom, gid) + return executeGetWithObject[T](ctx, n, query, obj, gid) +} + +func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { + query := ` + { + obj(func: eq(%s, %s)) { + uid + expand(_all_) + dgraph.type + %s + } + } + ` + + return executeGet[T](ctx, n, query, cf) } -func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField, readFrom bool) (uint64, *T, error) { +func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, cf ConstrainedField, obj T) (uint64, *T, error) { query := ` { obj(func: eq(%s, %s)) { @@ -231,16 +268,20 @@ func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf Constrai } ` - return executeGet[T](ctx, n, query, readFrom, cf) + return executeGetWithObject[T](ctx, n, query, obj, cf) } -func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, query string, readFrom bool, args ...R) (uint64, *T, error) { +func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, query string, args ...R) (uint64, *T, error) { if len(args) != 1 { return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args)) } var obj T + return executeGetWithObject(ctx, n, query, obj, args...) +} + +func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, query string, obj T, args ...R) (uint64, *T, error) { t := reflect.TypeOf(obj) fieldToJsonTags, jsonToDbTag, reverseEdgeTags, err := getFieldTags(t) @@ -346,15 +387,15 @@ func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { }) } -func getUniqueConstraint[T any](object *T) (uint64, *ConstrainedField, error) { - t := reflect.TypeOf(*object) +func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { + t := reflect.TypeOf(object) fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) if err != nil { return 0, nil, err } - values := getJsonTagToValues(object, fieldToJsonTags) + jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) - for jsonName, value := range values { + for jsonName, value := range jsonTagToValue { if jsonName == "gid" { gid, ok := value.(uint64) if !ok { @@ -375,55 +416,54 @@ func getUniqueConstraint[T any](object *T) (uint64, *ConstrainedField, error) { return 0, nil, fmt.Errorf("unique constraint not defined for any field on type %s", t.Name()) } -func upsertHelper[T any](ctx context.Context, db *DB, n *Namespace, object *T, readFrom bool) (uint64, *T, bool, error) { +func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { + // if object == nil { + // return 0, nil, false, fmt.Errorf("object is nil") + // } gid, cf, err := getUniqueConstraint(object) if err != nil { - return 0, object, false, err + return 0, err + } + + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, err + } + + err = n.alterSchemaWithParsed(ctx, sch) + if err != nil { + return 0, err } if gid != 0 { - gid, object, err := getByGid[T](ctx, n, gid, readFrom) + gid, _, err := getByGidWithObject[T](ctx, n, gid, object) if err != nil { - return 0, object, false, err + return 0, err } - return gid, object, true, nil + return gid, nil } else if cf != nil { - gid, object, err := getByConstrainedField[T](ctx, n, *cf, readFrom) + gid, _, err := getByConstrainedFieldWithObject[T](ctx, n, *cf, object) if err == nil { - return gid, object, true, nil + return gid, nil } } gid, err = db.z.nextUID() if err != nil { - return 0, object, false, err + return 0, err } - dms := make([]*dql.Mutation, 0) - sch := &schema.ParsedSchema{} + dms = make([]*dql.Mutation, 0) err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) if err != nil { - return 0, object, false, err - } - - ctx = x.AttachNamespace(ctx, n.ID()) - - err = n.alterSchemaWithParsed(ctx, sch) - if err != nil { - return 0, object, false, err + return 0, err } err = applyDqlMutations(ctx, db, dms) if err != nil { - return 0, object, false, err - } - - v := reflect.ValueOf(object).Elem() - - gidField := v.FieldByName("Gid") - - if gidField.IsValid() && gidField.CanSet() && gidField.Kind() == reflect.Uint64 { - gidField.SetUint(gid) + return 0, err } - return gid, object, false, nil + return gid, nil } diff --git a/api_reflect.go b/api_reflect.go index 9e1ac2e..3584e5a 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -12,11 +12,11 @@ type dbTag struct { } func getFieldTags(t reflect.Type) (fieldToJsonTags map[string]string, - jsonToDbTags map[string]*dbTag, reverseEdgeTags map[string]string, err error) { + jsonToDbTags map[string]*dbTag, jsonToReverseEdgeTags map[string]string, err error) { fieldToJsonTags = make(map[string]string) jsonToDbTags = make(map[string]*dbTag) - reverseEdgeTags = make(map[string]string) + jsonToReverseEdgeTags = make(map[string]string) for i := 0; i < t.NumField(); i++ { field := t.Field(i) jsonTag := field.Tag.Get("json") @@ -35,7 +35,7 @@ func getFieldTags(t reflect.Type) (fieldToJsonTags map[string]string, } t := strings.Split(typeAndField[0], "=")[1] f := strings.Split(typeAndField[1], "=")[1] - reverseEdgeTags[field.Name] = getPredicateName(t, f) + jsonToReverseEdgeTags[jsonName] = getPredicateName(t, f) } dbConstraintsTag := field.Tag.Get("db") @@ -50,12 +50,15 @@ func getFieldTags(t reflect.Type) (fieldToJsonTags map[string]string, } } } - return fieldToJsonTags, jsonToDbTags, reverseEdgeTags, nil + return fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, nil } func getJsonTagToValues(object any, fieldToJsonTags map[string]string) map[string]any { values := make(map[string]any) - v := reflect.ValueOf(object).Elem() + v := reflect.ValueOf(object) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } for fieldName, jsonName := range fieldToJsonTags { fieldValue := v.FieldByName(fieldName) values[jsonName] = fieldValue.Interface() diff --git a/api_test.go b/api_test.go index 5b1c6bc..c729de5 100644 --- a/api_test.go +++ b/api_test.go @@ -16,6 +16,51 @@ type User struct { ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` } +func TestFirstTimeUser(t *testing.T) { + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + + gid, user, err := modusdb.Create(db, &User{ + Name: "A", + Age: 10, + ClerkId: "123", + }) + + require.NoError(t, err) + require.Equal(t, user.Gid, gid) + require.Equal(t, "A", user.Name) + require.Equal(t, 10, user.Age) + require.Equal(t, "123", user.ClerkId) + + gid, queriedUser, err := modusdb.Get[User](db, gid) + + require.NoError(t, err) + require.Equal(t, queriedUser.Gid, gid) + require.Equal(t, 10, queriedUser.Age) + require.Equal(t, "A", queriedUser.Name) + require.Equal(t, "123", queriedUser.ClerkId) + + gid, queriedUser2, err := modusdb.Get[User](db, modusdb.ConstrainedField{ + Key: "clerk_id", + Value: "123", + }) + + require.NoError(t, err) + require.Equal(t, queriedUser.Gid, gid) + require.Equal(t, 10, queriedUser2.Age) + require.Equal(t, "A", queriedUser2.Name) + require.Equal(t, "123", queriedUser2.ClerkId) + + _, _, err = modusdb.Delete[User](db, gid) + require.NoError(t, err) + + _, queriedUser3, err := modusdb.Get[User](db, gid) + require.Error(t, err) + require.Equal(t, "no object found", err.Error()) + require.Nil(t, queriedUser3) + +} + func TestCreateApi(t *testing.T) { ctx := context.Background() db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) @@ -37,8 +82,7 @@ func TestCreateApi(t *testing.T) { require.NoError(t, err) require.Equal(t, "B", user.Name) - require.Equal(t, uint64(2), gid) - require.Equal(t, uint64(2), user.Gid) + require.Equal(t, user.Gid, gid) query := `{ me(func: has(User.name)) { @@ -118,8 +162,7 @@ func TestGetApi(t *testing.T) { gid, queriedUser, err := modusdb.Get[User](db, gid, db1.ID()) require.NoError(t, err) - require.Equal(t, uint64(2), gid) - require.Equal(t, uint64(2), queriedUser.Gid) + require.Equal(t, queriedUser.Gid, gid) require.Equal(t, 20, queriedUser.Age) require.Equal(t, "B", queriedUser.Name) require.Equal(t, "123", queriedUser.ClerkId) @@ -151,8 +194,7 @@ func TestGetApiWithConstrainedField(t *testing.T) { }, db1.ID()) require.NoError(t, err) - require.Equal(t, uint64(2), gid) - require.Equal(t, uint64(2), queriedUser.Gid) + require.Equal(t, queriedUser.Gid, gid) require.Equal(t, 20, queriedUser.Age) require.Equal(t, "B", queriedUser.Name) require.Equal(t, "123", queriedUser.ClerkId) @@ -194,3 +236,97 @@ func TestDeleteApi(t *testing.T) { require.Equal(t, "no object found", err.Error()) require.Nil(t, queriedUser) } + +// func TestUpsertApi(t *testing.T) { +// ctx := context.Background() +// db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) +// require.NoError(t, err) +// defer db.Close() + +// db1, err := db.CreateNamespace() +// require.NoError(t, err) + +// require.NoError(t, db1.DropData(ctx)) + +// user := &User{ +// Name: "B", +// Age: 20, +// ClerkId: "123", +// } + +// gid, _, _, err := modusdb.Upsert(db, user, db1.ID()) +// require.NoError(t, err) +// require.Equal(t, user.Gid, gid) + +// user.Age = 21 +// gid, _, _, err = modusdb.Upsert(db, user, db1.ID()) +// require.NoError(t, err) +// require.Equal(t, user.Gid, gid) + +// _, queriedUser, err := modusdb.Get[User](db, gid, db1.ID()) +// require.NoError(t, err) +// require.Equal(t, uint64(2), queriedUser.Gid) +// require.Equal(t, 21, queriedUser.Age) +// require.Equal(t, "B", queriedUser.Name) +// require.Equal(t, "123", queriedUser.ClerkId) +// } + +type Project struct { + Gid uint64 `json:"gid,omitempty"` + Name string `json:"name,omitempty"` + ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` + // Branches []Branch `json:"branches,omitempty" readFrom:"type=Branch,field=proj"` +} + +type Branch struct { + Gid uint64 `json:"gid,omitempty"` + Name string `json:"name,omitempty"` + ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` + Proj Project `json:"proj,omitempty"` +} + +func TestCreateApiWithNestedObj(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + branch := &Branch{ + Name: "B", + ClerkId: "123", + Proj: Project{ + Name: "P", + ClerkId: "456", + }, + } + + gid, _, err := modusdb.Create(db, branch, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "B", branch.Name) + require.Equal(t, branch.Gid, gid) + require.Equal(t, "P", branch.Proj.Name) + + query := `{ + me(func: has(Branch.name)) { + uid + Branch.name + Branch.clerk_id + Branch.proj { + uid + Project.name + Project.clerk_id + } + } + }` + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, + `{"me":[{"uid":"0x2","Branch.name":"B","Branch.clerk_id":"123","Branch.proj":{"uid":"0x3","Project.name":"P","Project.clerk_id":"456"}}]}`, + string(resp.GetJson())) +} From ff0c92d04e1c72341376df74b50003e2f3664690 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Tue, 24 Dec 2024 16:44:19 -0800 Subject: [PATCH 03/12] add linking objects, reading from objects --- api.go | 50 +++------------ api_helper.go | 167 ++++++++++++++++++++----------------------------- api_reflect.go | 71 +++++++++++++++------ api_test.go | 73 ++++++++++++++++++++- api_types.go | 114 ++++++++++++++++++++++++++++++++- utils.go | 15 +++++ 6 files changed, 326 insertions(+), 164 deletions(-) create mode 100644 utils.go diff --git a/api.go b/api.go index 5588a1b..1bd31c1 100644 --- a/api.go +++ b/api.go @@ -1,46 +1,12 @@ package modusdb import ( - "context" "fmt" - "reflect" "github.com/dgraph-io/dgraph/v24/dql" "github.com/dgraph-io/dgraph/v24/schema" - "github.com/dgraph-io/dgraph/v24/x" ) -type ModusDbOption func(*modusDbOptions) - -type modusDbOptions struct { - namespace uint64 -} - -func WithNamespace(namespace uint64) ModusDbOption { - return func(o *modusDbOptions) { - o.namespace = namespace - } -} - -func getDefaultNamespace(db *DB, ns ...uint64) (context.Context, *Namespace, error) { - dbOpts := &modusDbOptions{ - namespace: db.defaultNamespace.ID(), - } - for _, ns := range ns { - WithNamespace(ns)(dbOpts) - } - - n, err := db.getNamespaceWithLock(dbOpts.namespace) - if err != nil { - return nil, nil, err - } - - ctx := context.Background() - ctx = x.AttachNamespace(ctx, n.ID()) - - return ctx, n, nil -} - func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { db.mutex.Lock() defer db.mutex.Unlock() @@ -74,20 +40,15 @@ func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { return 0, object, err } - v := reflect.ValueOf(object).Elem() - - gidField := v.FieldByName("Gid") - - if gidField.IsValid() && gidField.CanSet() && gidField.Kind() == reflect.Uint64 { - gidField.SetUint(gid) - } - - return gid, object, nil + return getByGid[T](ctx, n, gid) } func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, error) { db.mutex.Lock() defer db.mutex.Unlock() + if len(ns) > 1 { + return 0, nil, fmt.Errorf("only one namespace is allowed") + } ctx, n, err := getDefaultNamespace(db, ns...) if err != nil { return 0, nil, err @@ -106,6 +67,9 @@ func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, error) { db.mutex.Lock() defer db.mutex.Unlock() + if len(ns) > 1 { + return 0, nil, fmt.Errorf("only one namespace is allowed") + } ctx, n, err := getDefaultNamespace(db, ns...) if err != nil { return 0, nil, err diff --git a/api_helper.go b/api_helper.go index 57d84ef..2b7af2e 100644 --- a/api_helper.go +++ b/api_helper.go @@ -2,11 +2,9 @@ package modusdb import ( "context" - "encoding/binary" "encoding/json" "fmt" "reflect" - "time" "github.com/dgraph-io/dgo/v240/protos/api" "github.com/dgraph-io/dgraph/v24/dql" @@ -15,88 +13,8 @@ import ( "github.com/dgraph-io/dgraph/v24/schema" "github.com/dgraph-io/dgraph/v24/worker" "github.com/dgraph-io/dgraph/v24/x" - "github.com/twpayne/go-geom" - "github.com/twpayne/go-geom/encoding/wkb" ) -func getPredicateName(typeName, fieldName string) string { - return fmt.Sprint(typeName, ".", fieldName) -} - -func addNamespace(ns uint64, pred string) string { - return x.NamespaceAttr(ns, pred) -} - -func valueToPosting_ValType(v any) (pb.Posting_ValType, error) { - switch v.(type) { - case string: - return pb.Posting_STRING, nil - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return pb.Posting_INT, nil - case bool: - return pb.Posting_BOOL, nil - case float32, float64: - return pb.Posting_FLOAT, nil - case []byte: - return pb.Posting_BINARY, nil - case time.Time: - return pb.Posting_DATETIME, nil - case geom.Point: - return pb.Posting_GEO, nil - case []float32, []float64: - return pb.Posting_VFLOAT, nil - default: - return pb.Posting_DEFAULT, fmt.Errorf("unsupported type %T", v) - } -} - -func valueToApiVal(v any) (*api.Value, error) { - switch val := v.(type) { - case string: - return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil - case int: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int8: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int16: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int32: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int64: - return &api.Value{Val: &api.Value_IntVal{IntVal: val}}, nil - case uint8: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case uint16: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case uint32: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case bool: - return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}}, nil - case float32: - return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil - case float64: - return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil - case []byte: - return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil - case time.Time: - bytes, err := val.MarshalBinary() - if err != nil { - return nil, err - } - return &api.Value{Val: &api.Value_DateVal{DateVal: bytes}}, nil - case geom.Point: - bytes, err := wkb.Marshal(&val, binary.LittleEndian) - if err != nil { - return nil, err - } - return &api.Value{Val: &api.Value_GeoVal{GeoVal: bytes}}, nil - case uint, uint64: - return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}}, nil - default: - return nil, fmt.Errorf("unsupported type %T", v) - } -} - func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object T, gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { t := reflect.TypeOf(object) @@ -124,6 +42,7 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac reflectValueType := reflect.TypeOf(value) var nquad *api.NQuad if reflectValueType.Kind() == reflect.Struct { + value = reflect.ValueOf(value).Interface() newGid, err := getUidOrMutate(ctx, n.db, n, value) if err != nil { return err @@ -136,6 +55,25 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac Predicate: getPredicateName(t.Name(), jsonName), ObjectId: fmt.Sprint(newGid), } + } else if reflectValueType.Kind() == reflect.Pointer { + // dereference the pointer + reflectValueType = reflectValueType.Elem() + if reflectValueType.Kind() == reflect.Struct { + // convert value to pointer, and then dereference + value = reflect.ValueOf(value).Elem().Interface() + newGid, err := getUidOrMutate(ctx, n.db, n, value) + if err != nil { + return err + } + valType = pb.Posting_UID + + nquad = &api.NQuad{ + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: getPredicateName(t.Name(), jsonName), + ObjectId: fmt.Sprint(newGid), + } + } } else { valType, err = valueToPosting_ValType(value) if err != nil { @@ -216,7 +154,10 @@ func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, { obj(func: uid(%d)) { uid - expand(_all_) + expand(_all_) { + uid + expand(_all_) + } dgraph.type %s } @@ -231,14 +172,17 @@ func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, ob { obj(func: uid(%d)) { uid - expand(_all_) + expand(_all_) { + uid + expand(_all_) + } dgraph.type %s } } ` - return executeGetWithObject[T](ctx, n, query, obj, gid) + return executeGetWithObject[T](ctx, n, query, obj, false, gid) } func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { @@ -246,7 +190,10 @@ func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf Constrai { obj(func: eq(%s, %s)) { uid - expand(_all_) + expand(_all_) { + uid + expand(_all_) + } dgraph.type %s } @@ -261,14 +208,17 @@ func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, c { obj(func: eq(%s, %s)) { uid - expand(_all_) + expand(_all_) { + uid + expand(_all_) + } dgraph.type %s } } ` - return executeGetWithObject[T](ctx, n, query, obj, cf) + return executeGetWithObject[T](ctx, n, query, obj, false, cf) } func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, query string, args ...R) (uint64, *T, error) { @@ -278,25 +228,28 @@ func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, query s var obj T - return executeGetWithObject(ctx, n, query, obj, args...) + return executeGetWithObject(ctx, n, query, obj, true, args...) } -func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, query string, obj T, args ...R) (uint64, *T, error) { +func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, query string, + obj T, withReverse bool, args ...R) (uint64, *T, error) { t := reflect.TypeOf(obj) - fieldToJsonTags, jsonToDbTag, reverseEdgeTags, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := getFieldTags(t) if err != nil { return 0, nil, err } readFromQuery := "" - for fieldName, reverseEdgeTag := range reverseEdgeTags { - readFromQuery += fmt.Sprintf(` + if withReverse { + for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { + readFromQuery += fmt.Sprintf(` %s: ~%s { uid expand(_all_) dgraph.type } - `, getPredicateName(t.Name(), fieldToJsonTags[fieldName]), reverseEdgeTag) + `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + } } var cf ConstrainedField @@ -318,7 +271,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac return 0, nil, err } - dynamicType := createDynamicStruct(t, fieldToJsonTags) + dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) dynamicInstance := reflect.New(dynamicType).Interface() @@ -345,7 +298,23 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac return 0, nil, err } - return gid, finalObject.(*T), nil + // Convert to *interface{} then to *T + if ifacePtr, ok := finalObject.(*interface{}); ok { + if typedPtr, ok := (*ifacePtr).(*T); ok { + return gid, typedPtr, nil + } + } + + // If conversion fails, try direct conversion + if typedPtr, ok := finalObject.(*T); ok { + return gid, typedPtr, nil + } + + if dirType, ok := finalObject.(T); ok { + return gid, &dirType, nil + } + + return 0, nil, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) } func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { @@ -417,9 +386,6 @@ func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { } func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { - // if object == nil { - // return 0, nil, false, fmt.Errorf("object is nil") - // } gid, cf, err := getUniqueConstraint(object) if err != nil { return 0, err @@ -444,6 +410,9 @@ func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) return gid, nil } else if cf != nil { gid, _, err := getByConstrainedFieldWithObject[T](ctx, n, *cf, object) + if err != nil && err != ErrNoObjFound { + return 0, err + } if err == nil { return gid, nil } diff --git a/api_reflect.go b/api_reflect.go index 3584e5a..d54a088 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -67,16 +67,38 @@ func getJsonTagToValues(object any, fieldToJsonTags map[string]string) map[strin return values } -func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string) reflect.Type { +func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, depth int) reflect.Type { fields := make([]reflect.StructField, 0, len(fieldToJsonTags)) for fieldName, jsonName := range fieldToJsonTags { field, _ := t.FieldByName(fieldName) if fieldName != "Gid" { - fields = append(fields, reflect.StructField{ - Name: field.Name, - Type: field.Type, - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), - }) + if field.Type.Kind() == reflect.Struct { + if depth <= 2 { + nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type) + nestedType := createDynamicStruct(field.Type, nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: nestedType, + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } + } else if field.Type.Kind() == reflect.Ptr && + field.Type.Elem().Kind() == reflect.Struct { + nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type.Elem()) + nestedType := createDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: reflect.PointerTo(nestedType), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } else { + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: field.Type, + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } + } } fields = append(fields, reflect.StructField{ @@ -98,28 +120,39 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { gid := uint64(0) for i := 0; i < vDynamic.NumField(); i++ { - field := vDynamic.Type().Field(i) - value := vDynamic.Field(i) + + dynamicField := vDynamic.Type().Field(i) + dynamicFieldType := dynamicField.Type + dynamicValue := vDynamic.Field(i) var finalField reflect.Value - if field.Name == "Uid" { + if dynamicField.Name == "Uid" { finalField = vFinal.FieldByName("Gid") - gidStr := value.String() + gidStr := dynamicValue.String() gid, _ = strconv.ParseUint(gidStr, 0, 64) - } else if field.Name == "DgraphType" { - fieldArr := value.Interface().([]string) + } else if dynamicField.Name == "DgraphType" { + fieldArr := dynamicValue.Interface().([]string) if len(fieldArr) == 0 { return 0, ErrNoObjFound } } else { - finalField = vFinal.FieldByName(field.Name) + finalField = vFinal.FieldByName(dynamicField.Name) } - if finalField.IsValid() && finalField.CanSet() { - // if field name is uid, convert it to uint64 - if field.Name == "Uid" { - finalField.SetUint(gid) - } else { - finalField.Set(value) + if dynamicFieldType.Kind() == reflect.Struct { + mapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface()) + } else if dynamicFieldType.Kind() == reflect.Ptr && + dynamicFieldType.Elem().Kind() == reflect.Struct { + // if field is a pointer, find if the underlying is a struct + mapDynamicToFinal(dynamicValue.Interface(), finalField.Interface()) + + } else { + if finalField.IsValid() && finalField.CanSet() { + // if field name is uid, convert it to uint64 + if dynamicField.Name == "Uid" { + finalField.SetUint(gid) + } else { + finalField.Set(dynamicValue) + } } } } diff --git a/api_test.go b/api_test.go index c729de5..7d92ca0 100644 --- a/api_test.go +++ b/api_test.go @@ -285,7 +285,7 @@ type Branch struct { Proj Project `json:"proj,omitempty"` } -func TestCreateApiWithNestedObj(t *testing.T) { +func TestNestedObjectMutation(t *testing.T) { ctx := context.Background() db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) require.NoError(t, err) @@ -305,11 +305,12 @@ func TestCreateApiWithNestedObj(t *testing.T) { }, } - gid, _, err := modusdb.Create(db, branch, db1.ID()) + gid, branch, err := modusdb.Create(db, branch, db1.ID()) require.NoError(t, err) require.Equal(t, "B", branch.Name) require.Equal(t, branch.Gid, gid) + require.NotEqual(t, uint64(0), branch.Proj.Gid) require.Equal(t, "P", branch.Proj.Name) query := `{ @@ -329,4 +330,72 @@ func TestCreateApiWithNestedObj(t *testing.T) { require.JSONEq(t, `{"me":[{"uid":"0x2","Branch.name":"B","Branch.clerk_id":"123","Branch.proj":{"uid":"0x3","Project.name":"P","Project.clerk_id":"456"}}]}`, string(resp.GetJson())) + + gid, queriedBranch, err := modusdb.Get[Branch](db, gid, db1.ID()) + require.NoError(t, err) + require.Equal(t, queriedBranch.Gid, gid) + require.Equal(t, "B", queriedBranch.Name) + +} + +func TestLinkingObjects(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + projGid, project, err := modusdb.Create(db, &Project{ + Name: "P", + ClerkId: "456", + }, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "P", project.Name) + require.Equal(t, project.Gid, projGid) + + branch := &Branch{ + Name: "B", + ClerkId: "123", + Proj: Project{ + Name: "P", + ClerkId: "456", + }, + } + + gid, branch, err := modusdb.Create(db, branch, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "B", branch.Name) + require.Equal(t, branch.Gid, gid) + require.Equal(t, projGid, branch.Proj.Gid) + require.Equal(t, "P", branch.Proj.Name) + + query := `{ + me(func: has(Branch.name)) { + uid + Branch.name + Branch.clerk_id + Branch.proj { + uid + Project.name + Project.clerk_id + } + } + }` + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, + `{"me":[{"uid":"0x3","Branch.name":"B","Branch.clerk_id":"123","Branch.proj":{"uid":"0x2","Project.name":"P","Project.clerk_id":"456"}}]}`, + string(resp.GetJson())) + + gid, queriedBranch, err := modusdb.Get[Branch](db, gid, db1.ID()) + require.NoError(t, err) + require.Equal(t, queriedBranch.Gid, gid) + require.Equal(t, "B", queriedBranch.Name) + } diff --git a/api_types.go b/api_types.go index 9cea571..6d3dc98 100644 --- a/api_types.go +++ b/api_types.go @@ -1,6 +1,17 @@ package modusdb -import "fmt" +import ( + "context" + "encoding/binary" + "fmt" + "time" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/x" + "github.com/twpayne/go-geom" + "github.com/twpayne/go-geom/encoding/wkb" +) var ( ErrNoObjFound = fmt.Errorf("no object found") @@ -13,3 +24,104 @@ type ConstrainedField struct { Key string Value any } + +type ModusDbOption func(*modusDbOptions) + +type modusDbOptions struct { + namespace uint64 +} + +func WithNamespace(namespace uint64) ModusDbOption { + return func(o *modusDbOptions) { + o.namespace = namespace + } +} + +func getDefaultNamespace(db *DB, ns ...uint64) (context.Context, *Namespace, error) { + dbOpts := &modusDbOptions{ + namespace: db.defaultNamespace.ID(), + } + for _, ns := range ns { + WithNamespace(ns)(dbOpts) + } + + n, err := db.getNamespaceWithLock(dbOpts.namespace) + if err != nil { + return nil, nil, err + } + + ctx := context.Background() + ctx = x.AttachNamespace(ctx, n.ID()) + + return ctx, n, nil +} + +func valueToPosting_ValType(v any) (pb.Posting_ValType, error) { + switch v.(type) { + case string: + return pb.Posting_STRING, nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return pb.Posting_INT, nil + case bool: + return pb.Posting_BOOL, nil + case float32, float64: + return pb.Posting_FLOAT, nil + case []byte: + return pb.Posting_BINARY, nil + case time.Time: + return pb.Posting_DATETIME, nil + case geom.Point: + return pb.Posting_GEO, nil + case []float32, []float64: + return pb.Posting_VFLOAT, nil + default: + return pb.Posting_DEFAULT, fmt.Errorf("unsupported type %T", v) + } +} + +func valueToApiVal(v any) (*api.Value, error) { + switch val := v.(type) { + case string: + return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil + case int: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int8: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int16: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int32: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int64: + return &api.Value{Val: &api.Value_IntVal{IntVal: val}}, nil + case uint8: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint16: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint32: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case bool: + return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}}, nil + case float32: + return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil + case float64: + return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil + case []byte: + return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil + case time.Time: + bytes, err := val.MarshalBinary() + if err != nil { + return nil, err + } + return &api.Value{Val: &api.Value_DateVal{DateVal: bytes}}, nil + case geom.Point: + bytes, err := wkb.Marshal(&val, binary.LittleEndian) + if err != nil { + return nil, err + } + return &api.Value{Val: &api.Value_GeoVal{GeoVal: bytes}}, nil + case uint, uint64: + return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}}, nil + default: + return nil, fmt.Errorf("unsupported type %T", v) + } +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..133ee1d --- /dev/null +++ b/utils.go @@ -0,0 +1,15 @@ +package modusdb + +import ( + "fmt" + + "github.com/dgraph-io/dgraph/v24/x" +) + +func getPredicateName(typeName, fieldName string) string { + return fmt.Sprint(typeName, ".", fieldName) +} + +func addNamespace(ns uint64, pred string) string { + return x.NamespaceAttr(ns, pred) +} From 571f0b43c3801310f99bd8801b296014ed993226 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Tue, 24 Dec 2024 17:08:40 -0800 Subject: [PATCH 04/12] add upsert --- api.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++- api_helper.go | 12 +++++---- api_test.go | 66 ++++++++++++++++++++++----------------------- 3 files changed, 114 insertions(+), 39 deletions(-) diff --git a/api.go b/api.go index 1bd31c1..7fcdccd 100644 --- a/api.go +++ b/api.go @@ -25,7 +25,7 @@ func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema(ctx, n, *object, gid, &dms, sch) + err = generateCreateDqlMutationsAndSchema[T](ctx, n, *object, gid, &dms, sch) if err != nil { return 0, object, err } @@ -43,6 +43,79 @@ func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { return getByGid[T](ctx, n, gid) } +func Upsert[T any](db *DB, object *T, ns ...uint64) (uint64, *T, bool, error) { + + var wasFound bool + db.mutex.Lock() + defer db.mutex.Unlock() + if len(ns) > 1 { + return 0, object, false, fmt.Errorf("only one namespace is allowed") + } + if object == nil { + return 0, nil, false, fmt.Errorf("object is nil") + } + + ctx, n, err := getDefaultNamespace(db, ns...) + if err != nil { + return 0, object, false, err + } + + gid, cf, err := getUniqueConstraint[T](*object) + if err != nil { + return 0, nil, false, err + } + + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateCreateDqlMutationsAndSchema[T](ctx, n, *object, gid, &dms, sch) + if err != nil { + return 0, nil, false, err + } + + err = n.alterSchemaWithParsed(ctx, sch) + if err != nil { + return 0, nil, false, err + } + + if gid != 0 { + gid, _, err = getByGidWithObject[T](ctx, n, gid, *object) + if err != nil && err != ErrNoObjFound { + return 0, nil, false, err + } + wasFound = err == nil + } else if cf != nil { + gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, *object) + if err != nil && err != ErrNoObjFound { + return 0, nil, false, err + } + wasFound = err == nil + } + if gid == 0 { + gid, err = db.z.nextUID() + if err != nil { + return 0, nil, false, err + } + } + + dms = make([]*dql.Mutation, 0) + err = generateCreateDqlMutationsAndSchema[T](ctx, n, *object, gid, &dms, sch) + if err != nil { + return 0, nil, false, err + } + + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, nil, false, err + } + + gid, object, err = getByGid[T](ctx, n, gid) + if err != nil { + return 0, nil, false, err + } + + return gid, object, wasFound, nil +} + func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, error) { db.mutex.Lock() defer db.mutex.Unlock() diff --git a/api_helper.go b/api_helper.go index 2b7af2e..088dc1f 100644 --- a/api_helper.go +++ b/api_helper.go @@ -386,7 +386,7 @@ func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { } func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { - gid, cf, err := getUniqueConstraint(object) + gid, cf, err := getUniqueConstraint[T](object) if err != nil { return 0, err } @@ -403,13 +403,15 @@ func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) return 0, err } if gid != 0 { - gid, _, err := getByGidWithObject[T](ctx, n, gid, object) - if err != nil { + gid, _, err = getByGidWithObject[T](ctx, n, gid, object) + if err != nil && err != ErrNoObjFound { return 0, err } - return gid, nil + if err == nil { + return gid, nil + } } else if cf != nil { - gid, _, err := getByConstrainedFieldWithObject[T](ctx, n, *cf, object) + gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) if err != nil && err != ErrNoObjFound { return 0, err } diff --git a/api_test.go b/api_test.go index 7d92ca0..42fa49c 100644 --- a/api_test.go +++ b/api_test.go @@ -237,39 +237,39 @@ func TestDeleteApi(t *testing.T) { require.Nil(t, queriedUser) } -// func TestUpsertApi(t *testing.T) { -// ctx := context.Background() -// db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) -// require.NoError(t, err) -// defer db.Close() - -// db1, err := db.CreateNamespace() -// require.NoError(t, err) - -// require.NoError(t, db1.DropData(ctx)) - -// user := &User{ -// Name: "B", -// Age: 20, -// ClerkId: "123", -// } - -// gid, _, _, err := modusdb.Upsert(db, user, db1.ID()) -// require.NoError(t, err) -// require.Equal(t, user.Gid, gid) - -// user.Age = 21 -// gid, _, _, err = modusdb.Upsert(db, user, db1.ID()) -// require.NoError(t, err) -// require.Equal(t, user.Gid, gid) - -// _, queriedUser, err := modusdb.Get[User](db, gid, db1.ID()) -// require.NoError(t, err) -// require.Equal(t, uint64(2), queriedUser.Gid) -// require.Equal(t, 21, queriedUser.Age) -// require.Equal(t, "B", queriedUser.Name) -// require.Equal(t, "123", queriedUser.ClerkId) -// } +func TestUpsertApi(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + user := &User{ + Name: "B", + Age: 20, + ClerkId: "123", + } + + gid, user, _, err := modusdb.Upsert(db, user, db1.ID()) + require.NoError(t, err) + require.Equal(t, user.Gid, gid) + + user.Age = 21 + gid, _, _, err = modusdb.Upsert(db, user, db1.ID()) + require.NoError(t, err) + require.Equal(t, user.Gid, gid) + + _, queriedUser, err := modusdb.Get[User](db, gid, db1.ID()) + require.NoError(t, err) + require.Equal(t, user.Gid, queriedUser.Gid) + require.Equal(t, 21, queriedUser.Age) + require.Equal(t, "B", queriedUser.Name) + require.Equal(t, "123", queriedUser.ClerkId) +} type Project struct { Gid uint64 `json:"gid,omitempty"` From bb924d749f91c090260b73933b3fab408c350bc6 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Tue, 24 Dec 2024 17:12:09 -0800 Subject: [PATCH 05/12] pass ci --- api_reflect.go | 10 ++++++++-- api_test.go | 3 ++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/api_reflect.go b/api_reflect.go index d54a088..798b9d4 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -139,11 +139,17 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { finalField = vFinal.FieldByName(dynamicField.Name) } if dynamicFieldType.Kind() == reflect.Struct { - mapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface()) + _, err := mapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface()) + if err != nil { + return 0, err + } } else if dynamicFieldType.Kind() == reflect.Ptr && dynamicFieldType.Elem().Kind() == reflect.Struct { // if field is a pointer, find if the underlying is a struct - mapDynamicToFinal(dynamicValue.Interface(), finalField.Interface()) + _, err := mapDynamicToFinal(dynamicValue.Interface(), finalField.Interface()) + if err != nil { + return 0, err + } } else { if finalField.IsValid() && finalField.CanSet() { diff --git a/api_test.go b/api_test.go index 42fa49c..3a80ecd 100644 --- a/api_test.go +++ b/api_test.go @@ -19,6 +19,7 @@ type User struct { func TestFirstTimeUser(t *testing.T) { db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) require.NoError(t, err) + defer db.Close() gid, user, err := modusdb.Create(db, &User{ Name: "A", @@ -78,7 +79,7 @@ func TestCreateApi(t *testing.T) { ClerkId: "123", } - gid, _, err := modusdb.Create(db, user, db1.ID()) + gid, user, err := modusdb.Create(db, user, db1.ID()) require.NoError(t, err) require.Equal(t, "B", user.Name) From c2cc8f60cd68b618f7a681fcb81d7e4615928929 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Tue, 24 Dec 2024 17:28:07 -0800 Subject: [PATCH 06/12] fix tests --- api_helper.go | 4 ++++ api_test.go | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/api_helper.go b/api_helper.go index 088dc1f..8365083 100644 --- a/api_helper.go +++ b/api_helper.go @@ -157,6 +157,7 @@ func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, expand(_all_) { uid expand(_all_) + dgraph.type } dgraph.type %s @@ -175,6 +176,7 @@ func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, ob expand(_all_) { uid expand(_all_) + dgraph.type } dgraph.type %s @@ -193,6 +195,7 @@ func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf Constrai expand(_all_) { uid expand(_all_) + dgraph.type } dgraph.type %s @@ -211,6 +214,7 @@ func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, c expand(_all_) { uid expand(_all_) + dgraph.type } dgraph.type %s diff --git a/api_test.go b/api_test.go index 3a80ecd..8aaf0a6 100644 --- a/api_test.go +++ b/api_test.go @@ -339,7 +339,7 @@ func TestNestedObjectMutation(t *testing.T) { } -func TestLinkingObjects(t *testing.T) { +func TestLinkingObjectsByConstrainedFields(t *testing.T) { ctx := context.Background() db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) require.NoError(t, err) @@ -400,3 +400,64 @@ func TestLinkingObjects(t *testing.T) { require.Equal(t, "B", queriedBranch.Name) } + +func TestLinkingObjectsByGid(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + projGid, project, err := modusdb.Create(db, &Project{ + Name: "P", + ClerkId: "456", + }, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "P", project.Name) + require.Equal(t, project.Gid, projGid) + + branch := &Branch{ + Name: "B", + ClerkId: "123", + Proj: Project{ + Gid: projGid, + }, + } + + gid, branch, err := modusdb.Create(db, branch, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "B", branch.Name) + require.Equal(t, branch.Gid, gid) + require.Equal(t, projGid, branch.Proj.Gid) + require.Equal(t, "P", branch.Proj.Name) + + query := `{ + me(func: has(Branch.name)) { + uid + Branch.name + Branch.clerk_id + Branch.proj { + uid + Project.name + Project.clerk_id + } + } + }` + resp, err := db1.Query(ctx, query) + require.NoError(t, err) + require.JSONEq(t, + `{"me":[{"uid":"0x3","Branch.name":"B","Branch.clerk_id":"123","Branch.proj":{"uid":"0x2","Project.name":"P","Project.clerk_id":"456"}}]}`, + string(resp.GetJson())) + + gid, queriedBranch, err := modusdb.Get[Branch](db, gid, db1.ID()) + require.NoError(t, err) + require.Equal(t, queriedBranch.Gid, gid) + require.Equal(t, "B", queriedBranch.Name) + +} From 786d10882939c0b1e33e9d7751b68d6e76188950 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Tue, 24 Dec 2024 17:29:07 -0800 Subject: [PATCH 07/12] lint --- api_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api_test.go b/api_test.go index 8aaf0a6..cd913db 100644 --- a/api_test.go +++ b/api_test.go @@ -329,7 +329,8 @@ func TestNestedObjectMutation(t *testing.T) { resp, err := db1.Query(ctx, query) require.NoError(t, err) require.JSONEq(t, - `{"me":[{"uid":"0x2","Branch.name":"B","Branch.clerk_id":"123","Branch.proj":{"uid":"0x3","Project.name":"P","Project.clerk_id":"456"}}]}`, + `{"me":[{"uid":"0x2","Branch.name":"B","Branch.clerk_id":"123","Branch.proj": + {"uid":"0x3","Project.name":"P","Project.clerk_id":"456"}}]}`, string(resp.GetJson())) gid, queriedBranch, err := modusdb.Get[Branch](db, gid, db1.ID()) @@ -391,7 +392,8 @@ func TestLinkingObjectsByConstrainedFields(t *testing.T) { resp, err := db1.Query(ctx, query) require.NoError(t, err) require.JSONEq(t, - `{"me":[{"uid":"0x3","Branch.name":"B","Branch.clerk_id":"123","Branch.proj":{"uid":"0x2","Project.name":"P","Project.clerk_id":"456"}}]}`, + `{"me":[{"uid":"0x3","Branch.name":"B","Branch.clerk_id":"123","Branch.proj": + {"uid":"0x2","Project.name":"P","Project.clerk_id":"456"}}]}`, string(resp.GetJson())) gid, queriedBranch, err := modusdb.Get[Branch](db, gid, db1.ID()) From 2e9eb3b621d49fe6654635d9fea6e6eb35345133 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Tue, 24 Dec 2024 17:42:46 -0800 Subject: [PATCH 08/12] . --- api_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api_test.go b/api_test.go index cd913db..558ace9 100644 --- a/api_test.go +++ b/api_test.go @@ -454,7 +454,8 @@ func TestLinkingObjectsByGid(t *testing.T) { resp, err := db1.Query(ctx, query) require.NoError(t, err) require.JSONEq(t, - `{"me":[{"uid":"0x3","Branch.name":"B","Branch.clerk_id":"123","Branch.proj":{"uid":"0x2","Project.name":"P","Project.clerk_id":"456"}}]}`, + `{"me":[{"uid":"0x3","Branch.name":"B","Branch.clerk_id":"123", + "Branch.proj":{"uid":"0x2","Project.name":"P","Project.clerk_id":"456"}}]}`, string(resp.GetJson())) gid, queriedBranch, err := modusdb.Get[Branch](db, gid, db1.ID()) From 2ca2dcfcca92901697cd10e26867ef1a3c2d6c42 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Wed, 25 Dec 2024 00:42:44 -0800 Subject: [PATCH 09/12] fix test and lint --- api_helper.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/api_helper.go b/api_helper.go index 8365083..5bcf15e 100644 --- a/api_helper.go +++ b/api_helper.go @@ -206,7 +206,9 @@ func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf Constrai return executeGet[T](ctx, n, query, cf) } -func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, cf ConstrainedField, obj T) (uint64, *T, error) { +func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, + cf ConstrainedField, obj T) (uint64, *T, error) { + query := ` { obj(func: eq(%s, %s)) { @@ -372,13 +374,17 @@ func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { if jsonName == "gid" { gid, ok := value.(uint64) if !ok { - return 0, nil, fmt.Errorf("expected uint64 type for gid, got %T", value) + continue } if gid != 0 { return gid, nil, nil } } if jsonToDbTags[jsonName] != nil && jsonToDbTags[jsonName].constraint == "unique" { + // check if value is zero or nil + if value == reflect.Zero(reflect.TypeOf(value)).Interface() || value == nil { + continue + } return 0, &ConstrainedField{ Key: jsonName, Value: value, From 5acb47cd960f9f1f30c151347a841c87bb37bec8 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Wed, 25 Dec 2024 07:44:40 -0800 Subject: [PATCH 10/12] add requirements on unique constraint --- api_helper.go | 8 +++++++- api_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ api_types.go | 3 ++- 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/api_helper.go b/api_helper.go index 5bcf15e..0314d08 100644 --- a/api_helper.go +++ b/api_helper.go @@ -29,11 +29,13 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) nquads := make([]*api.NQuad, 0) + uniqueConstraintFound := false for jsonName, value := range jsonTagToValue { if jsonToReverseEdgeTags[jsonName] != "" { continue } if jsonName == "gid" { + uniqueConstraintFound = true continue } var val *api.Value @@ -98,6 +100,7 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac if jsonToDbTags[jsonName] != nil { constraint := jsonToDbTags[jsonName].constraint if constraint == "unique" || constraint == "term" { + uniqueConstraintFound = true u.Directive = pb.SchemaUpdate_INDEX if constraint == "unique" { u.Tokenizer = []string{"exact"} @@ -110,6 +113,9 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac sch.Preds = append(sch.Preds, u) nquads = append(nquads, nquad) } + if !uniqueConstraintFound { + return fmt.Errorf(NoUniqueConstr, t.Name()) + } sch.Types = append(sch.Types, &pb.TypeUpdate{ TypeName: addNamespace(n.id, t.Name()), Fields: sch.Preds, @@ -392,7 +398,7 @@ func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { } } - return 0, nil, fmt.Errorf("unique constraint not defined for any field on type %s", t.Name()) + return 0, nil, fmt.Errorf(NoUniqueConstr, t.Name()) } func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { diff --git a/api_test.go b/api_test.go index 558ace9..1d3293c 100644 --- a/api_test.go +++ b/api_test.go @@ -2,6 +2,7 @@ package modusdb_test import ( "context" + "fmt" "testing" "github.com/stretchr/testify/require" @@ -464,3 +465,50 @@ func TestLinkingObjectsByGid(t *testing.T) { require.Equal(t, "B", queriedBranch.Name) } + +type BadProject struct { + Name string `json:"name,omitempty"` + ClerkId string `json:"clerk_id,omitempty"` +} + +type BadBranch struct { + Gid uint64 `json:"gid,omitempty"` + Name string `json:"name,omitempty"` + ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` + Proj BadProject `json:"proj,omitempty"` +} + +func TestNestedObjectMutationWithBadType(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + branch := &BadBranch{ + Name: "B", + ClerkId: "123", + Proj: BadProject{ + Name: "P", + ClerkId: "456", + }, + } + + _, _, err = modusdb.Create(db, branch, db1.ID()) + require.Error(t, err) + require.Equal(t, fmt.Sprintf(modusdb.NoUniqueConstr, "BadProject"), err.Error()) + + proj := &BadProject{ + Name: "P", + ClerkId: "456", + } + + _, _, err = modusdb.Create(db, proj, db1.ID()) + require.Error(t, err) + require.Equal(t, fmt.Sprintf(modusdb.NoUniqueConstr, "BadProject"), err.Error()) + +} diff --git a/api_types.go b/api_types.go index 6d3dc98..811a2ff 100644 --- a/api_types.go +++ b/api_types.go @@ -14,7 +14,8 @@ import ( ) var ( - ErrNoObjFound = fmt.Errorf("no object found") + ErrNoObjFound = fmt.Errorf("no object found") + NoUniqueConstr = "unique constraint not defined for any field on type %s" ) type UniqueField interface { From 2d6978e72fc5febac647e690116ced7a58252f71 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Wed, 25 Dec 2024 17:31:28 -0800 Subject: [PATCH 11/12] cleanup (#37) --- api_dql.go | 41 ++++ api_helper.go | 456 ------------------------------------------- api_mutate_helper.go | 240 +++++++++++++++++++++++ api_query_helper.go | 123 ++++++++++++ api_reflect.go | 33 ++++ api_types.go | 8 +- 6 files changed, 443 insertions(+), 458 deletions(-) create mode 100644 api_dql.go delete mode 100644 api_helper.go create mode 100644 api_mutate_helper.go create mode 100644 api_query_helper.go diff --git a/api_dql.go b/api_dql.go new file mode 100644 index 0000000..e7d3c38 --- /dev/null +++ b/api_dql.go @@ -0,0 +1,41 @@ +package modusdb + +import "fmt" + +type QueryFunc func() string + +const ( + objQuery = ` + { + obj(%s) { + uid + expand(_all_) { + uid + expand(_all_) + dgraph.type + } + dgraph.type + %s + } + } + ` + + funcUid = `func: uid(%d)` + funcEq = `func: eq(%s, %s)` +) + +func buildUidQuery(gid uint64) QueryFunc { + return func() string { + return fmt.Sprintf("func: uid(%d)", gid) + } +} + +func buildEqQuery(key, value any) QueryFunc { + return func() string { + return fmt.Sprintf("func: eq(%s, %s)", key, value) + } +} + +func formatObjQuery(qf QueryFunc, extraFields string) string { + return fmt.Sprintf(objQuery, qf(), extraFields) +} diff --git a/api_helper.go b/api_helper.go deleted file mode 100644 index 0314d08..0000000 --- a/api_helper.go +++ /dev/null @@ -1,456 +0,0 @@ -package modusdb - -import ( - "context" - "encoding/json" - "fmt" - "reflect" - - "github.com/dgraph-io/dgo/v240/protos/api" - "github.com/dgraph-io/dgraph/v24/dql" - "github.com/dgraph-io/dgraph/v24/protos/pb" - "github.com/dgraph-io/dgraph/v24/query" - "github.com/dgraph-io/dgraph/v24/schema" - "github.com/dgraph-io/dgraph/v24/worker" - "github.com/dgraph-io/dgraph/v24/x" -) - -func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object T, - gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { - t := reflect.TypeOf(object) - if t.Kind() != reflect.Struct { - return fmt.Errorf("expected struct, got %s", t.Kind()) - } - - fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, err := getFieldTags(t) - if err != nil { - return err - } - jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) - - nquads := make([]*api.NQuad, 0) - uniqueConstraintFound := false - for jsonName, value := range jsonTagToValue { - if jsonToReverseEdgeTags[jsonName] != "" { - continue - } - if jsonName == "gid" { - uniqueConstraintFound = true - continue - } - var val *api.Value - var valType pb.Posting_ValType - - reflectValueType := reflect.TypeOf(value) - var nquad *api.NQuad - if reflectValueType.Kind() == reflect.Struct { - value = reflect.ValueOf(value).Interface() - newGid, err := getUidOrMutate(ctx, n.db, n, value) - if err != nil { - return err - } - valType = pb.Posting_UID - - nquad = &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: getPredicateName(t.Name(), jsonName), - ObjectId: fmt.Sprint(newGid), - } - } else if reflectValueType.Kind() == reflect.Pointer { - // dereference the pointer - reflectValueType = reflectValueType.Elem() - if reflectValueType.Kind() == reflect.Struct { - // convert value to pointer, and then dereference - value = reflect.ValueOf(value).Elem().Interface() - newGid, err := getUidOrMutate(ctx, n.db, n, value) - if err != nil { - return err - } - valType = pb.Posting_UID - - nquad = &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: getPredicateName(t.Name(), jsonName), - ObjectId: fmt.Sprint(newGid), - } - } - } else { - valType, err = valueToPosting_ValType(value) - if err != nil { - return err - } - val, err = valueToApiVal(value) - if err != nil { - return err - } - - nquad = &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: getPredicateName(t.Name(), jsonName), - ObjectValue: val, - } - } - u := &pb.SchemaUpdate{ - Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), - ValueType: valType, - } - if jsonToDbTags[jsonName] != nil { - constraint := jsonToDbTags[jsonName].constraint - if constraint == "unique" || constraint == "term" { - uniqueConstraintFound = true - u.Directive = pb.SchemaUpdate_INDEX - if constraint == "unique" { - u.Tokenizer = []string{"exact"} - } else { - u.Tokenizer = []string{"term"} - } - } - } - - sch.Preds = append(sch.Preds, u) - nquads = append(nquads, nquad) - } - if !uniqueConstraintFound { - return fmt.Errorf(NoUniqueConstr, t.Name()) - } - sch.Types = append(sch.Types, &pb.TypeUpdate{ - TypeName: addNamespace(n.id, t.Name()), - Fields: sch.Preds, - }) - - val, err := valueToApiVal(t.Name()) - if err != nil { - return err - } - nquad := &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: "dgraph.type", - ObjectValue: val, - } - nquads = append(nquads, nquad) - - *dms = append(*dms, &dql.Mutation{ - Set: nquads, - }) - - return nil -} - -func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { - return []*dql.Mutation{{ - Del: []*api.NQuad{ - { - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: x.Star, - ObjectValue: &api.Value{ - Val: &api.Value_DefaultVal{DefaultVal: x.Star}, - }, - }, - }, - }} -} - -func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { - query := ` - { - obj(func: uid(%d)) { - uid - expand(_all_) { - uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - return executeGet[T](ctx, n, query, gid) -} - -func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, obj T) (uint64, *T, error) { - query := ` - { - obj(func: uid(%d)) { - uid - expand(_all_) { - uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - return executeGetWithObject[T](ctx, n, query, obj, false, gid) -} - -func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { - query := ` - { - obj(func: eq(%s, %s)) { - uid - expand(_all_) { - uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - return executeGet[T](ctx, n, query, cf) -} - -func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, - cf ConstrainedField, obj T) (uint64, *T, error) { - - query := ` - { - obj(func: eq(%s, %s)) { - uid - expand(_all_) { - uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - return executeGetWithObject[T](ctx, n, query, obj, false, cf) -} - -func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, query string, args ...R) (uint64, *T, error) { - if len(args) != 1 { - return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args)) - } - - var obj T - - return executeGetWithObject(ctx, n, query, obj, true, args...) -} - -func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, query string, - obj T, withReverse bool, args ...R) (uint64, *T, error) { - t := reflect.TypeOf(obj) - - fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := getFieldTags(t) - if err != nil { - return 0, nil, err - } - readFromQuery := "" - if withReverse { - for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(` - %s: ~%s { - uid - expand(_all_) - dgraph.type - } - `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) - } - } - - var cf ConstrainedField - gid, ok := any(args[0]).(uint64) - if ok { - query = fmt.Sprintf(query, gid, readFromQuery) - } else if cf, ok = any(args[0]).(ConstrainedField); ok { - query = fmt.Sprintf(query, getPredicateName(t.Name(), cf.Key), cf.Value, readFromQuery) - } else { - return 0, nil, fmt.Errorf("invalid unique field type") - } - - if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { - return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) - } - - resp, err := n.queryWithLock(ctx, query) - if err != nil { - return 0, nil, err - } - - dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) - - dynamicInstance := reflect.New(dynamicType).Interface() - - var result struct { - Obj []any `json:"obj"` - } - - result.Obj = append(result.Obj, dynamicInstance) - - // Unmarshal the JSON response into the dynamic struct - if err := json.Unmarshal(resp.Json, &result); err != nil { - return 0, nil, err - } - - // Check if we have at least one object in the response - if len(result.Obj) == 0 { - return 0, nil, ErrNoObjFound - } - - // Map the dynamic struct to the final type T - finalObject := reflect.New(t).Interface() - gid, err = mapDynamicToFinal(result.Obj[0], finalObject) - if err != nil { - return 0, nil, err - } - - // Convert to *interface{} then to *T - if ifacePtr, ok := finalObject.(*interface{}); ok { - if typedPtr, ok := (*ifacePtr).(*T); ok { - return gid, typedPtr, nil - } - } - - // If conversion fails, try direct conversion - if typedPtr, ok := finalObject.(*T); ok { - return gid, typedPtr, nil - } - - if dirType, ok := finalObject.(T); ok { - return gid, &dirType, nil - } - - return 0, nil, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) -} - -func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { - edges, err := query.ToDirectedEdges(dms, nil) - if err != nil { - return err - } - - if !db.isOpen { - return ErrClosedDB - } - - startTs, err := db.z.nextTs() - if err != nil { - return err - } - commitTs, err := db.z.nextTs() - if err != nil { - return err - } - - m := &pb.Mutations{ - GroupId: 1, - StartTs: startTs, - Edges: edges, - } - m.Edges, err = query.ExpandEdges(ctx, m) - if err != nil { - return fmt.Errorf("error expanding edges: %w", err) - } - - p := &pb.Proposal{Mutations: m, StartTs: startTs} - if err := worker.ApplyMutations(ctx, p); err != nil { - return err - } - - return worker.ApplyCommited(ctx, &pb.OracleDelta{ - Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, - }) -} - -func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { - t := reflect.TypeOf(object) - fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) - if err != nil { - return 0, nil, err - } - jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) - - for jsonName, value := range jsonTagToValue { - if jsonName == "gid" { - gid, ok := value.(uint64) - if !ok { - continue - } - if gid != 0 { - return gid, nil, nil - } - } - if jsonToDbTags[jsonName] != nil && jsonToDbTags[jsonName].constraint == "unique" { - // check if value is zero or nil - if value == reflect.Zero(reflect.TypeOf(value)).Interface() || value == nil { - continue - } - return 0, &ConstrainedField{ - Key: jsonName, - Value: value, - }, nil - } - } - - return 0, nil, fmt.Errorf(NoUniqueConstr, t.Name()) -} - -func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { - gid, cf, err := getUniqueConstraint[T](object) - if err != nil { - return 0, err - } - - dms := make([]*dql.Mutation, 0) - sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) - if err != nil { - return 0, err - } - - err = n.alterSchemaWithParsed(ctx, sch) - if err != nil { - return 0, err - } - if gid != 0 { - gid, _, err = getByGidWithObject[T](ctx, n, gid, object) - if err != nil && err != ErrNoObjFound { - return 0, err - } - if err == nil { - return gid, nil - } - } else if cf != nil { - gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) - if err != nil && err != ErrNoObjFound { - return 0, err - } - if err == nil { - return gid, nil - } - } - - gid, err = db.z.nextUID() - if err != nil { - return 0, err - } - - dms = make([]*dql.Mutation, 0) - err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) - if err != nil { - return 0, err - } - - err = applyDqlMutations(ctx, db, dms) - if err != nil { - return 0, err - } - - return gid, nil -} diff --git a/api_mutate_helper.go b/api_mutate_helper.go new file mode 100644 index 0000000..7957e0a --- /dev/null +++ b/api_mutate_helper.go @@ -0,0 +1,240 @@ +package modusdb + +import ( + "context" + "fmt" + "reflect" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/dql" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/query" + "github.com/dgraph-io/dgraph/v24/schema" + "github.com/dgraph-io/dgraph/v24/worker" + "github.com/dgraph-io/dgraph/v24/x" +) + +func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object T, + gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { + t := reflect.TypeOf(object) + if t.Kind() != reflect.Struct { + return fmt.Errorf("expected struct, got %s", t.Kind()) + } + + fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, err := getFieldTags(t) + if err != nil { + return err + } + jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) + + nquads := make([]*api.NQuad, 0) + uniqueConstraintFound := false + for jsonName, value := range jsonTagToValue { + if jsonToReverseEdgeTags[jsonName] != "" { + continue + } + if jsonName == "gid" { + uniqueConstraintFound = true + continue + } + var val *api.Value + var valType pb.Posting_ValType + + reflectValueType := reflect.TypeOf(value) + var nquad *api.NQuad + + if reflectValueType.Kind() == reflect.Struct { + value = reflect.ValueOf(value).Interface() + newGid, err := getUidOrMutate(ctx, n.db, n, value) + if err != nil { + return err + } + value = newGid + } else if reflectValueType.Kind() == reflect.Pointer { + // dereference the pointer + reflectValueType = reflectValueType.Elem() + if reflectValueType.Kind() == reflect.Struct { + // convert value to pointer, and then dereference + value = reflect.ValueOf(value).Elem().Interface() + newGid, err := getUidOrMutate(ctx, n.db, n, value) + if err != nil { + return err + } + value = newGid + } + } + valType, err = valueToPosting_ValType(value) + if err != nil { + return err + } + val, err = valueToApiVal(value) + if err != nil { + return err + } + + nquad = &api.NQuad{ + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: getPredicateName(t.Name(), jsonName), + } + + if valType == pb.Posting_UID { + nquad.ObjectId = fmt.Sprint(value) + } else { + nquad.ObjectValue = val + } + + u := &pb.SchemaUpdate{ + Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), + ValueType: valType, + } + if jsonToDbTags[jsonName] != nil { + constraint := jsonToDbTags[jsonName].constraint + if constraint == "unique" || constraint == "term" { + uniqueConstraintFound = true + u.Directive = pb.SchemaUpdate_INDEX + if constraint == "unique" { + u.Tokenizer = []string{"exact"} + } else { + u.Tokenizer = []string{"term"} + } + } + } + + sch.Preds = append(sch.Preds, u) + nquads = append(nquads, nquad) + } + if !uniqueConstraintFound { + return fmt.Errorf(NoUniqueConstr, t.Name()) + } + sch.Types = append(sch.Types, &pb.TypeUpdate{ + TypeName: addNamespace(n.id, t.Name()), + Fields: sch.Preds, + }) + + val, err := valueToApiVal(t.Name()) + if err != nil { + return err + } + typeNquad := &api.NQuad{ + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: "dgraph.type", + ObjectValue: val, + } + nquads = append(nquads, typeNquad) + + *dms = append(*dms, &dql.Mutation{ + Set: nquads, + }) + + return nil +} + +func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { + return []*dql.Mutation{{ + Del: []*api.NQuad{ + { + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: x.Star, + ObjectValue: &api.Value{ + Val: &api.Value_DefaultVal{DefaultVal: x.Star}, + }, + }, + }, + }} +} + +func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { + edges, err := query.ToDirectedEdges(dms, nil) + if err != nil { + return err + } + + if !db.isOpen { + return ErrClosedDB + } + + startTs, err := db.z.nextTs() + if err != nil { + return err + } + commitTs, err := db.z.nextTs() + if err != nil { + return err + } + + m := &pb.Mutations{ + GroupId: 1, + StartTs: startTs, + Edges: edges, + } + m.Edges, err = query.ExpandEdges(ctx, m) + if err != nil { + return fmt.Errorf("error expanding edges: %w", err) + } + + p := &pb.Proposal{Mutations: m, StartTs: startTs} + if err := worker.ApplyMutations(ctx, p); err != nil { + return err + } + + return worker.ApplyCommited(ctx, &pb.OracleDelta{ + Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, + }) +} + +func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { + gid, cf, err := getUniqueConstraint[T](object) + if err != nil { + return 0, err + } + + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, err + } + + err = n.alterSchemaWithParsed(ctx, sch) + if err != nil { + return 0, err + } + if gid != 0 { + gid, _, err = getByGidWithObject[T](ctx, n, gid, object) + if err != nil && err != ErrNoObjFound { + return 0, err + } + if err == nil { + return gid, nil + } + } else if cf != nil { + gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) + if err != nil && err != ErrNoObjFound { + return 0, err + } + if err == nil { + return gid, nil + } + } + + gid, err = db.z.nextUID() + if err != nil { + return 0, err + } + + dms = make([]*dql.Mutation, 0) + err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, err + } + + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, err + } + + return gid, nil +} diff --git a/api_query_helper.go b/api_query_helper.go new file mode 100644 index 0000000..e62d8f3 --- /dev/null +++ b/api_query_helper.go @@ -0,0 +1,123 @@ +package modusdb + +import ( + "context" + "encoding/json" + "fmt" + "reflect" +) + +func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { + return executeGet[T](ctx, n, gid) +} + +func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, obj T) (uint64, *T, error) { + return executeGetWithObject[T](ctx, n, obj, false, gid) +} + +func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { + return executeGet[T](ctx, n, cf) +} + +func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, + cf ConstrainedField, obj T) (uint64, *T, error) { + + return executeGetWithObject[T](ctx, n, obj, false, cf) +} + +func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, args ...R) (uint64, *T, error) { + if len(args) != 1 { + return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args)) + } + + var obj T + + return executeGetWithObject(ctx, n, obj, true, args...) +} + +func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, + obj T, withReverse bool, args ...R) (uint64, *T, error) { + t := reflect.TypeOf(obj) + + fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := getFieldTags(t) + if err != nil { + return 0, nil, err + } + readFromQuery := "" + if withReverse { + for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { + readFromQuery += fmt.Sprintf(` + %s: ~%s { + uid + expand(_all_) + dgraph.type + } + `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + } + } + + var cf ConstrainedField + var query string + gid, ok := any(args[0]).(uint64) + if ok { + query = formatObjQuery(buildUidQuery(gid), readFromQuery) + } else if cf, ok = any(args[0]).(ConstrainedField); ok { + query = formatObjQuery(buildEqQuery(getPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) + } else { + return 0, nil, fmt.Errorf("invalid unique field type") + } + + if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { + return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) + } + + resp, err := n.queryWithLock(ctx, query) + if err != nil { + return 0, nil, err + } + + dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) + + dynamicInstance := reflect.New(dynamicType).Interface() + + var result struct { + Obj []any `json:"obj"` + } + + result.Obj = append(result.Obj, dynamicInstance) + + // Unmarshal the JSON response into the dynamic struct + if err := json.Unmarshal(resp.Json, &result); err != nil { + return 0, nil, err + } + + // Check if we have at least one object in the response + if len(result.Obj) == 0 { + return 0, nil, ErrNoObjFound + } + + // Map the dynamic struct to the final type T + finalObject := reflect.New(t).Interface() + gid, err = mapDynamicToFinal(result.Obj[0], finalObject) + if err != nil { + return 0, nil, err + } + + // Convert to *interface{} then to *T + if ifacePtr, ok := finalObject.(*interface{}); ok { + if typedPtr, ok := (*ifacePtr).(*T); ok { + return gid, typedPtr, nil + } + } + + // If conversion fails, try direct conversion + if typedPtr, ok := finalObject.(*T); ok { + return gid, typedPtr, nil + } + + if dirType, ok := finalObject.(T); ok { + return gid, &dirType, nil + } + + return 0, nil, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) +} diff --git a/api_reflect.go b/api_reflect.go index 798b9d4..f74cd53 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -164,3 +164,36 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { } return gid, nil } + +func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { + t := reflect.TypeOf(object) + fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) + if err != nil { + return 0, nil, err + } + jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) + + for jsonName, value := range jsonTagToValue { + if jsonName == "gid" { + gid, ok := value.(uint64) + if !ok { + continue + } + if gid != 0 { + return gid, nil, nil + } + } + if jsonToDbTags[jsonName] != nil && jsonToDbTags[jsonName].constraint == "unique" { + // check if value is zero or nil + if value == reflect.Zero(reflect.TypeOf(value)).Interface() || value == nil { + continue + } + return 0, &ConstrainedField{ + Key: jsonName, + Value: value, + }, nil + } + } + + return 0, nil, fmt.Errorf(NoUniqueConstr, t.Name()) +} diff --git a/api_types.go b/api_types.go index 811a2ff..860edda 100644 --- a/api_types.go +++ b/api_types.go @@ -61,8 +61,10 @@ func valueToPosting_ValType(v any) (pb.Posting_ValType, error) { switch v.(type) { case string: return pb.Posting_STRING, nil - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32: return pb.Posting_INT, nil + case uint64: + return pb.Posting_UID, nil case bool: return pb.Posting_BOOL, nil case float32, float64: @@ -100,6 +102,8 @@ func valueToApiVal(v any) (*api.Value, error) { return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil case uint32: return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint64: + return &api.Value{Val: &api.Value_UidVal{UidVal: val}}, nil case bool: return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}}, nil case float32: @@ -120,7 +124,7 @@ func valueToApiVal(v any) (*api.Value, error) { return nil, err } return &api.Value{Val: &api.Value_GeoVal{GeoVal: bytes}}, nil - case uint, uint64: + case uint: return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}}, nil default: return nil, fmt.Errorf("unsupported type %T", v) From 77aaa886d91ec37e85c44402bae419a1d422cd6c Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Wed, 25 Dec 2024 17:32:30 -0800 Subject: [PATCH 12/12] lint --- api_dql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api_dql.go b/api_dql.go index e7d3c38..a2b0e65 100644 --- a/api_dql.go +++ b/api_dql.go @@ -26,13 +26,13 @@ const ( func buildUidQuery(gid uint64) QueryFunc { return func() string { - return fmt.Sprintf("func: uid(%d)", gid) + return fmt.Sprintf(funcUid, gid) } } func buildEqQuery(key, value any) QueryFunc { return func() string { - return fmt.Sprintf("func: eq(%s, %s)", key, value) + return fmt.Sprintf(funcEq, key, value) } }