From b1e7baae65f6f455ecae6874c3c85bbc9fc1eef8 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Wed, 25 Dec 2024 17:25:29 -0800 Subject: [PATCH 1/2] cleanup --- api_dql.go | 55 ++++++++++++++++++++ api_helper.go | 137 +++++++++++++------------------------------------- api_types.go | 8 ++- 3 files changed, 96 insertions(+), 104 deletions(-) create mode 100644 api_dql.go diff --git a/api_dql.go b/api_dql.go new file mode 100644 index 0000000..6999a3b --- /dev/null +++ b/api_dql.go @@ -0,0 +1,55 @@ +package modusdb + +import "fmt" + +type QueryFunc func() string + +const ( + objQuery = ` + { + obj(%s) { + uid + expand(_all_) { + uid + expand(_all_) + dgraph.type + } + dgraph.type + %s + } + } + ` + unstructuredQuery = ` + { + obj(%s) { + uid + expand(_all_) { + uid + expand(_all_) + } + } + } + ` + funcUid = `func: uid(%d)` + funcEq = `func: eq(%s, %s)` +) + +func buildUidQuery(gid uint64) QueryFunc { + return func() string { + return fmt.Sprintf("func: uid(%d)", gid) + } +} + +func buildEqQuery(key, value any) QueryFunc { + return func() string { + return fmt.Sprintf("func: eq(%s, %s)", key, value) + } +} + +func formatObjQuery(qf QueryFunc, extraFields string) string { + return fmt.Sprintf(objQuery, qf(), extraFields) +} + +func formatUnstructuredQuery(qf QueryFunc) string { + return fmt.Sprintf(unstructuredQuery, qf()) +} diff --git a/api_helper.go b/api_helper.go index 0314d08..c2c25b4 100644 --- a/api_helper.go +++ b/api_helper.go @@ -43,20 +43,14 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac reflectValueType := reflect.TypeOf(value) var nquad *api.NQuad + if reflectValueType.Kind() == reflect.Struct { value = reflect.ValueOf(value).Interface() newGid, err := getUidOrMutate(ctx, n.db, n, value) if err != nil { return err } - valType = pb.Posting_UID - - nquad = &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: getPredicateName(t.Name(), jsonName), - ObjectId: fmt.Sprint(newGid), - } + value = newGid } else if reflectValueType.Kind() == reflect.Pointer { // dereference the pointer reflectValueType = reflectValueType.Elem() @@ -67,32 +61,30 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac if err != nil { return err } - valType = pb.Posting_UID - - nquad = &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: getPredicateName(t.Name(), jsonName), - ObjectId: fmt.Sprint(newGid), - } - } - } else { - valType, err = valueToPosting_ValType(value) - if err != nil { - return err - } - val, err = valueToApiVal(value) - if err != nil { - return err + value = newGid } + } + valType, err = valueToPosting_ValType(value) + if err != nil { + return err + } + val, err = valueToApiVal(value) + if err != nil { + return err + } - nquad = &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: getPredicateName(t.Name(), jsonName), - ObjectValue: val, - } + nquad = &api.NQuad{ + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: getPredicateName(t.Name(), jsonName), } + + if valType == pb.Posting_UID { + nquad.ObjectId = fmt.Sprint(value) + } else { + nquad.ObjectValue = val + } + u := &pb.SchemaUpdate{ Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), ValueType: valType, @@ -125,13 +117,13 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac if err != nil { return err } - nquad := &api.NQuad{ + typeNquad := &api.NQuad{ Namespace: n.ID(), Subject: fmt.Sprint(gid), Predicate: "dgraph.type", ObjectValue: val, } - nquads = append(nquads, nquad) + nquads = append(nquads, typeNquad) *dms = append(*dms, &dql.Mutation{ Set: nquads, @@ -156,94 +148,34 @@ func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { } func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { - query := ` - { - obj(func: uid(%d)) { - uid - expand(_all_) { - uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - return executeGet[T](ctx, n, query, gid) + return executeGet[T](ctx, n, gid) } func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, obj T) (uint64, *T, error) { - query := ` - { - obj(func: uid(%d)) { - uid - expand(_all_) { - uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - return executeGetWithObject[T](ctx, n, query, obj, false, gid) + return executeGetWithObject[T](ctx, n, obj, false, gid) } func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { - query := ` - { - obj(func: eq(%s, %s)) { - uid - expand(_all_) { - uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - return executeGet[T](ctx, n, query, cf) + return executeGet[T](ctx, n, cf) } func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, cf ConstrainedField, obj T) (uint64, *T, error) { - query := ` - { - obj(func: eq(%s, %s)) { - uid - expand(_all_) { - uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - return executeGetWithObject[T](ctx, n, query, obj, false, cf) + return executeGetWithObject[T](ctx, n, obj, false, cf) } -func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, query string, args ...R) (uint64, *T, error) { +func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, args ...R) (uint64, *T, error) { if len(args) != 1 { return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args)) } var obj T - return executeGetWithObject(ctx, n, query, obj, true, args...) + return executeGetWithObject(ctx, n, obj, true, args...) } -func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, query string, +func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, obj T, withReverse bool, args ...R) (uint64, *T, error) { t := reflect.TypeOf(obj) @@ -265,11 +197,12 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac } var cf ConstrainedField + var query string gid, ok := any(args[0]).(uint64) if ok { - query = fmt.Sprintf(query, gid, readFromQuery) + query = formatObjQuery(buildUidQuery(gid), readFromQuery) } else if cf, ok = any(args[0]).(ConstrainedField); ok { - query = fmt.Sprintf(query, getPredicateName(t.Name(), cf.Key), cf.Value, readFromQuery) + query = formatObjQuery(buildEqQuery(getPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) } else { return 0, nil, fmt.Errorf("invalid unique field type") } diff --git a/api_types.go b/api_types.go index 811a2ff..860edda 100644 --- a/api_types.go +++ b/api_types.go @@ -61,8 +61,10 @@ func valueToPosting_ValType(v any) (pb.Posting_ValType, error) { switch v.(type) { case string: return pb.Posting_STRING, nil - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32: return pb.Posting_INT, nil + case uint64: + return pb.Posting_UID, nil case bool: return pb.Posting_BOOL, nil case float32, float64: @@ -100,6 +102,8 @@ func valueToApiVal(v any) (*api.Value, error) { return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil case uint32: return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint64: + return &api.Value{Val: &api.Value_UidVal{UidVal: val}}, nil case bool: return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}}, nil case float32: @@ -120,7 +124,7 @@ func valueToApiVal(v any) (*api.Value, error) { return nil, err } return &api.Value{Val: &api.Value_GeoVal{GeoVal: bytes}}, nil - case uint, uint64: + case uint: return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}}, nil default: return nil, fmt.Errorf("unsupported type %T", v) From 29b8de85acfaf0942732815dbba882628361d8c3 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Wed, 25 Dec 2024 17:31:07 -0800 Subject: [PATCH 2/2] refactor a little --- api_dql.go | 16 +-- api_helper.go => api_mutate_helper.go | 149 -------------------------- api_query_helper.go | 123 +++++++++++++++++++++ api_reflect.go | 33 ++++++ 4 files changed, 157 insertions(+), 164 deletions(-) rename api_helper.go => api_mutate_helper.go (57%) create mode 100644 api_query_helper.go diff --git a/api_dql.go b/api_dql.go index 6999a3b..e7d3c38 100644 --- a/api_dql.go +++ b/api_dql.go @@ -19,17 +19,7 @@ const ( } } ` - unstructuredQuery = ` - { - obj(%s) { - uid - expand(_all_) { - uid - expand(_all_) - } - } - } - ` + funcUid = `func: uid(%d)` funcEq = `func: eq(%s, %s)` ) @@ -49,7 +39,3 @@ func buildEqQuery(key, value any) QueryFunc { func formatObjQuery(qf QueryFunc, extraFields string) string { return fmt.Sprintf(objQuery, qf(), extraFields) } - -func formatUnstructuredQuery(qf QueryFunc) string { - return fmt.Sprintf(unstructuredQuery, qf()) -} diff --git a/api_helper.go b/api_mutate_helper.go similarity index 57% rename from api_helper.go rename to api_mutate_helper.go index c2c25b4..7957e0a 100644 --- a/api_helper.go +++ b/api_mutate_helper.go @@ -2,7 +2,6 @@ package modusdb import ( "context" - "encoding/json" "fmt" "reflect" @@ -147,121 +146,6 @@ func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { }} } -func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { - return executeGet[T](ctx, n, gid) -} - -func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, obj T) (uint64, *T, error) { - return executeGetWithObject[T](ctx, n, obj, false, gid) -} - -func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { - return executeGet[T](ctx, n, cf) -} - -func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, - cf ConstrainedField, obj T) (uint64, *T, error) { - - return executeGetWithObject[T](ctx, n, obj, false, cf) -} - -func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, args ...R) (uint64, *T, error) { - if len(args) != 1 { - return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args)) - } - - var obj T - - return executeGetWithObject(ctx, n, obj, true, args...) -} - -func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, - obj T, withReverse bool, args ...R) (uint64, *T, error) { - t := reflect.TypeOf(obj) - - fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := getFieldTags(t) - if err != nil { - return 0, nil, err - } - readFromQuery := "" - if withReverse { - for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(` - %s: ~%s { - uid - expand(_all_) - dgraph.type - } - `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) - } - } - - var cf ConstrainedField - var query string - gid, ok := any(args[0]).(uint64) - if ok { - query = formatObjQuery(buildUidQuery(gid), readFromQuery) - } else if cf, ok = any(args[0]).(ConstrainedField); ok { - query = formatObjQuery(buildEqQuery(getPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) - } else { - return 0, nil, fmt.Errorf("invalid unique field type") - } - - if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { - return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) - } - - resp, err := n.queryWithLock(ctx, query) - if err != nil { - return 0, nil, err - } - - dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) - - dynamicInstance := reflect.New(dynamicType).Interface() - - var result struct { - Obj []any `json:"obj"` - } - - result.Obj = append(result.Obj, dynamicInstance) - - // Unmarshal the JSON response into the dynamic struct - if err := json.Unmarshal(resp.Json, &result); err != nil { - return 0, nil, err - } - - // Check if we have at least one object in the response - if len(result.Obj) == 0 { - return 0, nil, ErrNoObjFound - } - - // Map the dynamic struct to the final type T - finalObject := reflect.New(t).Interface() - gid, err = mapDynamicToFinal(result.Obj[0], finalObject) - if err != nil { - return 0, nil, err - } - - // Convert to *interface{} then to *T - if ifacePtr, ok := finalObject.(*interface{}); ok { - if typedPtr, ok := (*ifacePtr).(*T); ok { - return gid, typedPtr, nil - } - } - - // If conversion fails, try direct conversion - if typedPtr, ok := finalObject.(*T); ok { - return gid, typedPtr, nil - } - - if dirType, ok := finalObject.(T); ok { - return gid, &dirType, nil - } - - return 0, nil, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) -} - func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { edges, err := query.ToDirectedEdges(dms, nil) if err != nil { @@ -301,39 +185,6 @@ func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { }) } -func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { - t := reflect.TypeOf(object) - fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) - if err != nil { - return 0, nil, err - } - jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) - - for jsonName, value := range jsonTagToValue { - if jsonName == "gid" { - gid, ok := value.(uint64) - if !ok { - continue - } - if gid != 0 { - return gid, nil, nil - } - } - if jsonToDbTags[jsonName] != nil && jsonToDbTags[jsonName].constraint == "unique" { - // check if value is zero or nil - if value == reflect.Zero(reflect.TypeOf(value)).Interface() || value == nil { - continue - } - return 0, &ConstrainedField{ - Key: jsonName, - Value: value, - }, nil - } - } - - return 0, nil, fmt.Errorf(NoUniqueConstr, t.Name()) -} - func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { gid, cf, err := getUniqueConstraint[T](object) if err != nil { diff --git a/api_query_helper.go b/api_query_helper.go new file mode 100644 index 0000000..e62d8f3 --- /dev/null +++ b/api_query_helper.go @@ -0,0 +1,123 @@ +package modusdb + +import ( + "context" + "encoding/json" + "fmt" + "reflect" +) + +func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { + return executeGet[T](ctx, n, gid) +} + +func getByGidWithObject[T any](ctx context.Context, n *Namespace, gid uint64, obj T) (uint64, *T, error) { + return executeGetWithObject[T](ctx, n, obj, false, gid) +} + +func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { + return executeGet[T](ctx, n, cf) +} + +func getByConstrainedFieldWithObject[T any](ctx context.Context, n *Namespace, + cf ConstrainedField, obj T) (uint64, *T, error) { + + return executeGetWithObject[T](ctx, n, obj, false, cf) +} + +func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, args ...R) (uint64, *T, error) { + if len(args) != 1 { + return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args)) + } + + var obj T + + return executeGetWithObject(ctx, n, obj, true, args...) +} + +func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespace, + obj T, withReverse bool, args ...R) (uint64, *T, error) { + t := reflect.TypeOf(obj) + + fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := getFieldTags(t) + if err != nil { + return 0, nil, err + } + readFromQuery := "" + if withReverse { + for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { + readFromQuery += fmt.Sprintf(` + %s: ~%s { + uid + expand(_all_) + dgraph.type + } + `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + } + } + + var cf ConstrainedField + var query string + gid, ok := any(args[0]).(uint64) + if ok { + query = formatObjQuery(buildUidQuery(gid), readFromQuery) + } else if cf, ok = any(args[0]).(ConstrainedField); ok { + query = formatObjQuery(buildEqQuery(getPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) + } else { + return 0, nil, fmt.Errorf("invalid unique field type") + } + + if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { + return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) + } + + resp, err := n.queryWithLock(ctx, query) + if err != nil { + return 0, nil, err + } + + dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) + + dynamicInstance := reflect.New(dynamicType).Interface() + + var result struct { + Obj []any `json:"obj"` + } + + result.Obj = append(result.Obj, dynamicInstance) + + // Unmarshal the JSON response into the dynamic struct + if err := json.Unmarshal(resp.Json, &result); err != nil { + return 0, nil, err + } + + // Check if we have at least one object in the response + if len(result.Obj) == 0 { + return 0, nil, ErrNoObjFound + } + + // Map the dynamic struct to the final type T + finalObject := reflect.New(t).Interface() + gid, err = mapDynamicToFinal(result.Obj[0], finalObject) + if err != nil { + return 0, nil, err + } + + // Convert to *interface{} then to *T + if ifacePtr, ok := finalObject.(*interface{}); ok { + if typedPtr, ok := (*ifacePtr).(*T); ok { + return gid, typedPtr, nil + } + } + + // If conversion fails, try direct conversion + if typedPtr, ok := finalObject.(*T); ok { + return gid, typedPtr, nil + } + + if dirType, ok := finalObject.(T); ok { + return gid, &dirType, nil + } + + return 0, nil, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) +} diff --git a/api_reflect.go b/api_reflect.go index 798b9d4..f74cd53 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -164,3 +164,36 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { } return gid, nil } + +func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { + t := reflect.TypeOf(object) + fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) + if err != nil { + return 0, nil, err + } + jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) + + for jsonName, value := range jsonTagToValue { + if jsonName == "gid" { + gid, ok := value.(uint64) + if !ok { + continue + } + if gid != 0 { + return gid, nil, nil + } + } + if jsonToDbTags[jsonName] != nil && jsonToDbTags[jsonName].constraint == "unique" { + // check if value is zero or nil + if value == reflect.Zero(reflect.TypeOf(value)).Interface() || value == nil { + continue + } + return 0, &ConstrainedField{ + Key: jsonName, + Value: value, + }, nil + } + } + + return 0, nil, fmt.Errorf(NoUniqueConstr, t.Name()) +}