From 70a3d51989edf43efa782f588f10b6cf4e6ff8d0 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Thu, 2 Jan 2025 20:06:56 -0800 Subject: [PATCH 01/17] add reverse edge support in APIs --- api_dql.go | 8 +++--- api_mutate_helper.go | 37 ++++++++++++++++++------ api_query_helper.go | 2 +- api_reflect.go | 35 +++++++++++++++++------ api_test.go | 67 +++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 123 insertions(+), 26 deletions(-) diff --git a/api_dql.go b/api_dql.go index 203440a..b06a836 100644 --- a/api_dql.go +++ b/api_dql.go @@ -21,9 +21,9 @@ const ( objQuery = ` { obj(func: %s) { - uid + gid: uid expand(_all_) { - uid + gid: uid expand(_all_) dgraph.type } @@ -36,9 +36,9 @@ const ( objsQuery = ` { objs(func: type("%s")%s) @filter(%s) { - uid + gid: uid expand(_all_) { - uid + gid: uid expand(_all_) dgraph.type } diff --git a/api_mutate_helper.go b/api_mutate_helper.go index 2dd4188..f90c258 100644 --- a/api_mutate_helper.go +++ b/api_mutate_helper.go @@ -13,6 +13,7 @@ import ( "context" "fmt" "reflect" + "strings" "github.com/dgraph-io/dgo/v240/protos/api" "github.com/dgraph-io/dgraph/v24/dql" @@ -39,18 +40,34 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac nquads := make([]*api.NQuad, 0) uniqueConstraintFound := false for jsonName, value := range jsonTagToValue { + var val *api.Value + var valType pb.Posting_ValType + + reflectValueType := reflect.TypeOf(value) + var nquad *api.NQuad + if jsonToReverseEdgeTags[jsonName] != "" { + if reflectValueType.Kind() != reflect.Slice || reflectValueType.Elem().Kind() != reflect.Struct { + return fmt.Errorf("reverse edge %s should be a slice of structs", jsonName) + } + reverseEdge := jsonToReverseEdgeTags[jsonName] + typeName := strings.Split(reverseEdge, ".")[0] + u := &pb.SchemaUpdate{ + Predicate: addNamespace(n.id, reverseEdge), + ValueType: pb.Posting_UID, + Directive: pb.SchemaUpdate_REVERSE, + } + sch.Preds = append(sch.Preds, u) + sch.Types = append(sch.Types, &pb.TypeUpdate{ + TypeName: addNamespace(n.id, typeName), + Fields: []*pb.SchemaUpdate{u}, + }) continue } if jsonName == "gid" { uniqueConstraintFound = true continue } - var val *api.Value - var valType pb.Posting_ValType - - reflectValueType := reflect.TypeOf(value) - var nquad *api.NQuad if reflectValueType.Kind() == reflect.Struct { value = reflect.ValueOf(value).Interface() @@ -87,16 +104,18 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac Predicate: getPredicateName(t.Name(), jsonName), } + u := &pb.SchemaUpdate{ + Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), + ValueType: valType, + } + if valType == pb.Posting_UID { nquad.ObjectId = fmt.Sprint(value) + u.Directive = pb.SchemaUpdate_REVERSE } else { nquad.ObjectValue = val } - u := &pb.SchemaUpdate{ - Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), - ValueType: valType, - } if jsonToDbTags[jsonName] != nil { constraint := jsonToDbTags[jsonName].constraint if constraint == "vector" && valType != pb.Posting_VFLOAT { diff --git a/api_query_helper.go b/api_query_helper.go index 5c2552f..21eb056 100644 --- a/api_query_helper.go +++ b/api_query_helper.go @@ -56,7 +56,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { readFromQuery += fmt.Sprintf(` %s: ~%s { - uid + gid: uid expand(_all_) dgraph.type } diff --git a/api_reflect.go b/api_reflect.go index 1779a81..26fe7bf 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -100,6 +100,15 @@ func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, dept Type: reflect.PointerTo(nestedType), Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), }) + } else if field.Type.Kind() == reflect.Slice && + field.Type.Elem().Kind() == reflect.Struct { + nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type.Elem()) + nestedType := createDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: reflect.SliceOf(nestedType), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) } else { fields = append(fields, reflect.StructField{ Name: field.Name, @@ -111,9 +120,9 @@ func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, dept } } fields = append(fields, reflect.StructField{ - Name: "Uid", + Name: "Gid", Type: reflect.TypeOf(""), - Tag: reflect.StructTag(`json:"uid"`), + Tag: reflect.StructTag(`json:"gid"`), }, reflect.StructField{ Name: "DgraphType", Type: reflect.TypeOf([]string{}), @@ -135,13 +144,13 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { dynamicValue := vDynamic.Field(i) var finalField reflect.Value - if dynamicField.Name == "Uid" { + if dynamicField.Name == "Gid" { finalField = vFinal.FieldByName("Gid") gidStr := dynamicValue.String() gid, _ = strconv.ParseUint(gidStr, 0, 64) } else if dynamicField.Name == "DgraphType" { - fieldArr := dynamicValue.Interface().([]string) - if len(fieldArr) == 0 { + _, ok := dynamicValue.Interface().([]string) + if !ok { return 0, ErrNoObjFound } } else { @@ -159,11 +168,21 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { if err != nil { return 0, err } - + } else if dynamicFieldType.Kind() == reflect.Slice && + dynamicFieldType.Elem().Kind() == reflect.Struct { + for j := 0; j < dynamicValue.Len(); j++ { + sliceElem := dynamicValue.Index(j).Addr().Interface() + finalSliceElem := reflect.New(finalField.Type().Elem()).Elem() + _, err := mapDynamicToFinal(sliceElem, finalSliceElem.Addr().Interface()) + if err != nil { + return 0, err + } + finalField.Set(reflect.Append(finalField, finalSliceElem)) + } } else { if finalField.IsValid() && finalField.CanSet() { - // if field name is uid, convert it to uint64 - if dynamicField.Name == "Uid" { + // if field name is gid, convert it to uint64 + if dynamicField.Name == "Gid" { finalField.SetUint(gid) } else { finalField.Set(dynamicValue) diff --git a/api_test.go b/api_test.go index 0ae4bea..96af941 100644 --- a/api_test.go +++ b/api_test.go @@ -400,10 +400,10 @@ func TestQueryApiWithPaginiationAndSorting(t *testing.T) { } type Project struct { - Gid uint64 `json:"gid,omitempty"` - Name string `json:"name,omitempty"` - ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` - // Branches []Branch `json:"branches,omitempty" readFrom:"type=Branch,field=proj"` + Gid uint64 `json:"gid,omitempty"` + Name string `json:"name,omitempty"` + ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` + Branches []Branch `json:"branches,omitempty" readFrom:"type=Branch,field=proj"` } type Branch struct { @@ -413,6 +413,65 @@ type Branch struct { Proj Project `json:"proj,omitempty"` } +func TestReverseEdgeQuery(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + projGid, project, err := modusdb.Create(db, Project{ + Name: "P", + ClerkId: "456", + }, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "P", project.Name) + require.Equal(t, project.Gid, projGid) + + branch1 := Branch{ + Name: "B", + ClerkId: "123", + Proj: Project{ + Gid: projGid, + }, + } + + branch1Gid, branch1, err := modusdb.Create(db, branch1, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "B", branch1.Name) + require.Equal(t, branch1.Gid, branch1Gid) + require.Equal(t, projGid, branch1.Proj.Gid) + require.Equal(t, "P", branch1.Proj.Name) + + branch2 := Branch{ + Name: "B2", + ClerkId: "456", + Proj: Project{ + Gid: projGid, + }, + } + + branch2Gid, branch2, err := modusdb.Create(db, branch2, db1.ID()) + require.NoError(t, err) + require.Equal(t, "B2", branch2.Name) + require.Equal(t, branch2.Gid, branch2Gid) + require.Equal(t, projGid, branch2.Proj.Gid) + + getProjGid, queriedProject, err := modusdb.Get[Project](db, projGid, db1.ID()) + require.NoError(t, err) + require.Equal(t, projGid, getProjGid) + require.Equal(t, "P", queriedProject.Name) + require.Len(t, queriedProject.Branches, 2) + require.Equal(t, "B", queriedProject.Branches[0].Name) + require.Equal(t, "B2", queriedProject.Branches[1].Name) +} + func TestNestedObjectMutation(t *testing.T) { ctx := context.Background() db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) From cf3afd15270974234cfc78ec81d39da8c7ca5a31 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Thu, 2 Jan 2025 20:25:03 -0800 Subject: [PATCH 02/17] add support for reverse edges --- api_reflect.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/api_reflect.go b/api_reflect.go index 26fe7bf..6cab22e 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -82,7 +82,7 @@ func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, dept field, _ := t.FieldByName(fieldName) if fieldName != "Gid" { if field.Type.Kind() == reflect.Struct { - if depth <= 2 { + if depth <= 1 { nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type) nestedType := createDynamicStruct(field.Type, nestedFieldToJsonTags, depth+1) fields = append(fields, reflect.StructField{ @@ -149,9 +149,14 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { gidStr := dynamicValue.String() gid, _ = strconv.ParseUint(gidStr, 0, 64) } else if dynamicField.Name == "DgraphType" { - _, ok := dynamicValue.Interface().([]string) - if !ok { - return 0, ErrNoObjFound + fieldArrInterface := dynamicValue.Interface() + fieldArr, ok := fieldArrInterface.([]string) + if ok { + if len(fieldArr) == 0 { + return 0, ErrNoObjFound + } + } else { + return 0, fmt.Errorf("DgraphType field should be an array of strings") } } else { finalField = vFinal.FieldByName(dynamicField.Name) From 7df884a473cbb4154544370d7c6e38effd89c39a Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Thu, 2 Jan 2025 20:28:32 -0800 Subject: [PATCH 03/17] add changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2999085..b6f5041 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## UNRELEASED + +- feat: add readfrom json tag to support reverse edges [#49](https://github.com/hypermodeinc/modusDB/pull/49) + ## 2025-01-02 - Version 0.1.0 Baseline for the changelog. From 51e37ba9eda63ca839760a65819be1b662e51548 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Thu, 2 Jan 2025 20:37:39 -0800 Subject: [PATCH 04/17] add support for reverse edges in query --- api.go | 2 +- api_dql.go | 8 ++++++ api_query_helper.go | 16 ++--------- api_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 77 insertions(+), 16 deletions(-) diff --git a/api.go b/api.go index 379fccd..652ea19 100644 --- a/api.go +++ b/api.go @@ -155,7 +155,7 @@ func Query[T any](db *DB, queryParams QueryParams, ns ...uint64) ([]uint64, []T, return nil, nil, err } - return executeQuery[T](ctx, n, queryParams, false) + return executeQuery[T](ctx, n, queryParams, true) } func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, T, error) { diff --git a/api_dql.go b/api_dql.go index b06a836..6c63171 100644 --- a/api_dql.go +++ b/api_dql.go @@ -48,6 +48,14 @@ const ( } ` + reverseEdgeQuery = ` + %s: ~%s { + gid: uid + expand(_all_) + dgraph.type + } + ` + funcUid = `uid(%d)` funcEq = `eq(%s, %s)` funcSimilarTo = `similar_to(%s, %d, "[%s]")` diff --git a/api_query_helper.go b/api_query_helper.go index 21eb056..43cc022 100644 --- a/api_query_helper.go +++ b/api_query_helper.go @@ -54,13 +54,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac readFromQuery := "" if withReverse { for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(` - %s: ~%s { - gid: uid - expand(_all_) - dgraph.type - } - `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + readFromQuery += fmt.Sprintf(reverseEdgeQuery, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) } } @@ -152,13 +146,7 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar readFromQuery := "" if withReverse { for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(` - %s: ~%s { - uid - expand(_all_) - dgraph.type - } - `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + readFromQuery += fmt.Sprintf(reverseEdgeQuery, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) } } diff --git a/api_test.go b/api_test.go index 96af941..cd8f3a8 100644 --- a/api_test.go +++ b/api_test.go @@ -413,7 +413,7 @@ type Branch struct { Proj Project `json:"proj,omitempty"` } -func TestReverseEdgeQuery(t *testing.T) { +func TestReverseEdgeGet(t *testing.T) { ctx := context.Background() db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) require.NoError(t, err) @@ -470,6 +470,71 @@ func TestReverseEdgeQuery(t *testing.T) { require.Len(t, queriedProject.Branches, 2) require.Equal(t, "B", queriedProject.Branches[0].Name) require.Equal(t, "B2", queriedProject.Branches[1].Name) + + queryBranchesGids, queriedBranches, err := modusdb.Query[Branch](db, modusdb.QueryParams{}, db1.ID()) + require.NoError(t, err) + require.Len(t, queriedBranches, 2) + require.Len(t, queryBranchesGids, 2) + require.Equal(t, "B", queriedBranches[0].Name) + require.Equal(t, "B2", queriedBranches[1].Name) + + // If i query a branch and get the project, i shouldn't automatically have access to retrieve data about the branches of that project + require.Len(t, queriedBranches[0].Proj.Branches, 0) +} + +func TestReverseEdgeQuery(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + projects := []Project{ + {Name: "P1", ClerkId: "456"}, + {Name: "P2", ClerkId: "789"}, + } + + branchCounter := 1 + clerkCounter := 100 + + for _, project := range projects { + projGid, project, err := modusdb.Create(db, project, db1.ID()) + require.NoError(t, err) + require.Equal(t, project.Name, project.Name) + require.Equal(t, project.Gid, projGid) + + branches := []Branch{ + {Name: fmt.Sprintf("B%d", branchCounter), ClerkId: fmt.Sprintf("%d", clerkCounter), Proj: Project{Gid: projGid}}, + {Name: fmt.Sprintf("B%d", branchCounter+1), ClerkId: fmt.Sprintf("%d", clerkCounter+1), Proj: Project{Gid: projGid}}, + } + branchCounter += 2 + clerkCounter += 2 + + for _, branch := range branches { + branchGid, branch, err := modusdb.Create(db, branch, db1.ID()) + require.NoError(t, err) + require.Equal(t, branch.Name, branch.Name) + require.Equal(t, branch.Gid, branchGid) + require.Equal(t, projGid, branch.Proj.Gid) + } + } + + queriedProjectsGids, queriedProjects, err := modusdb.Query[Project](db, modusdb.QueryParams{}, db1.ID()) + require.NoError(t, err) + require.Len(t, queriedProjects, 2) + require.Len(t, queriedProjectsGids, 2) + require.Equal(t, "P1", queriedProjects[0].Name) + require.Equal(t, "P2", queriedProjects[1].Name) + require.Len(t, queriedProjects[0].Branches, 2) + require.Len(t, queriedProjects[1].Branches, 2) + require.Equal(t, "B1", queriedProjects[0].Branches[0].Name) + require.Equal(t, "B2", queriedProjects[0].Branches[1].Name) + require.Equal(t, "B3", queriedProjects[1].Branches[0].Name) + require.Equal(t, "B4", queriedProjects[1].Branches[1].Name) } func TestNestedObjectMutation(t *testing.T) { From 402e7a7d26c3dd5bbe0f79d7039c1cdee2b2cb0a Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Thu, 2 Jan 2025 20:38:53 -0800 Subject: [PATCH 05/17] . --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6f5041..4d92b85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,8 @@ ## UNRELEASED -- feat: add readfrom json tag to support reverse edges [#49](https://github.com/hypermodeinc/modusDB/pull/49) +- feat: add readfrom json tag to support reverse edges + [#49](https://github.com/hypermodeinc/modusDB/pull/49) ## 2025-01-02 - Version 0.1.0 From 932dc4f0fa78a0644098bb90982a38378a2e947e Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Thu, 2 Jan 2025 21:20:43 -0800 Subject: [PATCH 06/17] . --- api_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api_test.go b/api_test.go index cd8f3a8..cee7458 100644 --- a/api_test.go +++ b/api_test.go @@ -478,7 +478,7 @@ func TestReverseEdgeGet(t *testing.T) { require.Equal(t, "B", queriedBranches[0].Name) require.Equal(t, "B2", queriedBranches[1].Name) - // If i query a branch and get the project, i shouldn't automatically have access to retrieve data about the branches of that project + // max depth is 2 require.Len(t, queriedBranches[0].Proj.Branches, 0) } From 73c604b64e79f57b3734ae2e85741a7e38896fb0 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 07:46:21 -0800 Subject: [PATCH 07/17] add mutation checks, no-op on mutating a read from field, and when deleted, fix no objects found error --- api_test.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/api_test.go b/api_test.go index cee7458..4a68ce6 100644 --- a/api_test.go +++ b/api_test.go @@ -413,6 +413,34 @@ type Branch struct { Proj Project `json:"proj,omitempty"` } +func TestReverseEdgeMutation(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + projGid, project, err := modusdb.Create(db, Project{ + Name: "P", + ClerkId: "456", + Branches: []Branch{ + {Name: "B", ClerkId: "123"}, + {Name: "B2", ClerkId: "456"}, + }, + }, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "P", project.Name) + require.Equal(t, project.Gid, projGid) + + //modifying a read-only field will be a no-op + require.Len(t, project.Branches, 0) +} + func TestReverseEdgeGet(t *testing.T) { ctx := context.Background() db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) From 61d99f81521937012fbb4dd6bcd857981806b7e0 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 07:47:13 -0800 Subject: [PATCH 08/17] . --- api_query_helper.go | 4 ++-- api_reflect.go | 14 +++++++++----- api_test.go | 37 +++++++++++++------------------------ 3 files changed, 24 insertions(+), 31 deletions(-) diff --git a/api_query_helper.go b/api_query_helper.go index 43cc022..9ec211c 100644 --- a/api_query_helper.go +++ b/api_query_helper.go @@ -100,7 +100,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac // Map the dynamic struct to the final type T finalObject := reflect.New(t).Interface() - gid, err = mapDynamicToFinal(result.Obj[0], finalObject) + gid, err = mapDynamicToFinal(result.Obj[0], finalObject, false) if err != nil { return 0, obj, err } @@ -185,7 +185,7 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar var objs []T for _, obj := range result.Objs { finalObject := reflect.New(t).Interface() - gid, err := mapDynamicToFinal(obj, finalObject) + gid, err := mapDynamicToFinal(obj, finalObject, false) if err != nil { return nil, nil, err } diff --git a/api_reflect.go b/api_reflect.go index 6cab22e..5a32b01 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -131,7 +131,7 @@ func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, dept return reflect.StructOf(fields) } -func mapDynamicToFinal(dynamic any, final any) (uint64, error) { +func mapDynamicToFinal(dynamic any, final any, isNested bool) (uint64, error) { vFinal := reflect.ValueOf(final).Elem() vDynamic := reflect.ValueOf(dynamic).Elem() @@ -153,7 +153,11 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { fieldArr, ok := fieldArrInterface.([]string) if ok { if len(fieldArr) == 0 { - return 0, ErrNoObjFound + if !isNested { + return 0, ErrNoObjFound + } else { + continue + } } } else { return 0, fmt.Errorf("DgraphType field should be an array of strings") @@ -162,14 +166,14 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { finalField = vFinal.FieldByName(dynamicField.Name) } if dynamicFieldType.Kind() == reflect.Struct { - _, err := mapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface()) + _, err := mapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface(), true) if err != nil { return 0, err } } else if dynamicFieldType.Kind() == reflect.Ptr && dynamicFieldType.Elem().Kind() == reflect.Struct { // if field is a pointer, find if the underlying is a struct - _, err := mapDynamicToFinal(dynamicValue.Interface(), finalField.Interface()) + _, err := mapDynamicToFinal(dynamicValue.Interface(), finalField.Interface(), true) if err != nil { return 0, err } @@ -178,7 +182,7 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { for j := 0; j < dynamicValue.Len(); j++ { sliceElem := dynamicValue.Index(j).Addr().Interface() finalSliceElem := reflect.New(finalField.Type().Elem()).Elem() - _, err := mapDynamicToFinal(sliceElem, finalSliceElem.Addr().Interface()) + _, err := mapDynamicToFinal(sliceElem, finalSliceElem.Addr().Interface(), true) if err != nil { return 0, err } diff --git a/api_test.go b/api_test.go index 4a68ce6..f7f15e0 100644 --- a/api_test.go +++ b/api_test.go @@ -413,7 +413,7 @@ type Branch struct { Proj Project `json:"proj,omitempty"` } -func TestReverseEdgeMutation(t *testing.T) { +func TestReverseEdgeGet(t *testing.T) { ctx := context.Background() db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) require.NoError(t, err) @@ -437,29 +437,8 @@ func TestReverseEdgeMutation(t *testing.T) { require.Equal(t, "P", project.Name) require.Equal(t, project.Gid, projGid) - //modifying a read-only field will be a no-op + // modifying a read-only field will be a no-op require.Len(t, project.Branches, 0) -} - -func TestReverseEdgeGet(t *testing.T) { - ctx := context.Background() - db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) - require.NoError(t, err) - defer db.Close() - - db1, err := db.CreateNamespace() - require.NoError(t, err) - - require.NoError(t, db1.DropData(ctx)) - - projGid, project, err := modusdb.Create(db, Project{ - Name: "P", - ClerkId: "456", - }, db1.ID()) - require.NoError(t, err) - - require.Equal(t, "P", project.Name) - require.Equal(t, project.Gid, projGid) branch1 := Branch{ Name: "B", @@ -506,8 +485,18 @@ func TestReverseEdgeGet(t *testing.T) { require.Equal(t, "B", queriedBranches[0].Name) require.Equal(t, "B2", queriedBranches[1].Name) - // max depth is 2 + // max depth is 2, so we should not see the branches within project require.Len(t, queriedBranches[0].Proj.Branches, 0) + + _, _, err = modusdb.Delete[Project](db, projGid, db1.ID()) + require.NoError(t, err) + + queryBranchesGids, queriedBranches, err = modusdb.Query[Branch](db, modusdb.QueryParams{}, db1.ID()) + require.NoError(t, err) + require.Len(t, queriedBranches, 2) + require.Len(t, queryBranchesGids, 2) + require.Equal(t, "B", queriedBranches[0].Name) + require.Equal(t, "B2", queriedBranches[1].Name) } func TestReverseEdgeQuery(t *testing.T) { From 6f113ed108f562558eeb05192ba9c99645933eca Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 08:24:36 -0800 Subject: [PATCH 09/17] refactoring packages --- api.go | 7 +- api/dql_query/dql_query.go | 173 +++++++++++++++++++++++++ api/utils/reflect.go | 194 ++++++++++++++++++++++++++++ utils.go => api/utils/utils.go | 11 +- api_dql.go | 227 --------------------------------- api_mutate_helper.go | 25 ++-- api_query_helper.go | 31 +++-- api_reflect.go | 202 ++--------------------------- api_test.go | 5 +- api_types.go | 82 +++++++++--- 10 files changed, 482 insertions(+), 475 deletions(-) create mode 100644 api/dql_query/dql_query.go create mode 100644 api/utils/reflect.go rename utils.go => api/utils/utils.go (62%) delete mode 100644 api_dql.go diff --git a/api.go b/api.go index 652ea19..311baf0 100644 --- a/api.go +++ b/api.go @@ -14,6 +14,7 @@ import ( "github.com/dgraph-io/dgraph/v24/dql" "github.com/dgraph-io/dgraph/v24/schema" + "github.com/hypermodeinc/modusdb/api/utils" ) func Create[T any](db *DB, object T, ns ...uint64) (uint64, T, error) { @@ -66,7 +67,7 @@ func Upsert[T any](db *DB, object T, ns ...uint64) (uint64, T, bool, error) { return 0, object, false, err } - gid, cf, err := getUniqueConstraint[T](object) + gid, cf, err := GetUniqueConstraint[T](object) if err != nil { return 0, object, false, err } @@ -85,13 +86,13 @@ func Upsert[T any](db *DB, object T, ns ...uint64) (uint64, T, bool, error) { if gid != 0 { gid, _, err = getByGidWithObject[T](ctx, n, gid, object) - if err != nil && err != ErrNoObjFound { + if err != nil && err != utils.ErrNoObjFound { return 0, object, false, err } wasFound = err == nil } else if cf != nil { gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) - if err != nil && err != ErrNoObjFound { + if err != nil && err != utils.ErrNoObjFound { return 0, object, false, err } wasFound = err == nil diff --git a/api/dql_query/dql_query.go b/api/dql_query/dql_query.go new file mode 100644 index 0000000..54f09d0 --- /dev/null +++ b/api/dql_query/dql_query.go @@ -0,0 +1,173 @@ +package dql_query + +import ( + "fmt" + "strconv" + "strings" +) + +type QueryFunc func() string + +const ( + ObjQuery = ` + { + obj(func: %s) { + gid: uid + expand(_all_) { + gid: uid + expand(_all_) + dgraph.type + } + dgraph.type + %s + } + } + ` + + ObjsQuery = ` + { + objs(func: type("%s")%s) @filter(%s) { + gid: uid + expand(_all_) { + gid: uid + expand(_all_) + dgraph.type + } + dgraph.type + %s + } + } + ` + + ReverseEdgeQuery = ` + %s: ~%s { + gid: uid + expand(_all_) + dgraph.type + } + ` + + FuncUid = `uid(%d)` + FuncEq = `eq(%s, %s)` + FuncSimilarTo = `similar_to(%s, %d, "[%s]")` + FuncAllOfTerms = `allofterms(%s, "%s")` + FuncAnyOfTerms = `anyofterms(%s, "%s")` + FuncAllOfText = `alloftext(%s, "%s")` + FuncAnyOfText = `anyoftext(%s, "%s")` + FuncRegExp = `regexp(%s, /%s/)` + FuncLe = `le(%s, %s)` + FuncGe = `ge(%s, %s)` + FuncGt = `gt(%s, %s)` + FuncLt = `lt(%s, %s)` +) + +func BuildUidQuery(gid uint64) QueryFunc { + return func() string { + return fmt.Sprintf(FuncUid, gid) + } +} + +func BuildEqQuery(key string, value any) QueryFunc { + return func() string { + return fmt.Sprintf(FuncEq, key, value) + } +} + +func BuildSimilarToQuery(indexAttr string, topK int64, vec []float32) QueryFunc { + vecStrArr := make([]string, len(vec)) + for i := range vec { + vecStrArr[i] = strconv.FormatFloat(float64(vec[i]), 'f', -1, 32) + } + vecStr := strings.Join(vecStrArr, ",") + return func() string { + return fmt.Sprintf(FuncSimilarTo, indexAttr, topK, vecStr) + } +} + +func BuildAllOfTermsQuery(attr string, terms string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncAllOfTerms, attr, terms) + } +} + +func BuildAnyOfTermsQuery(attr string, terms string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncAnyOfTerms, attr, terms) + } +} + +func BuildAllOfTextQuery(attr, text string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncAllOfText, attr, text) + } +} + +func BuildAnyOfTextQuery(attr, text string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncAnyOfText, attr, text) + } +} + +func BuildRegExpQuery(attr, pattern string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncRegExp, attr, pattern) + } +} + +func BuildLeQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncLe, attr, value) + } +} + +func BuildGeQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncGe, attr, value) + } +} + +func BuildGtQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncGt, attr, value) + } +} + +func BuildLtQuery(attr, value string) QueryFunc { + return func() string { + return fmt.Sprintf(FuncLt, attr, value) + } +} + +func And(qfs ...QueryFunc) QueryFunc { + return func() string { + qs := make([]string, len(qfs)) + for i, qf := range qfs { + qs[i] = qf() + } + return strings.Join(qs, " AND ") + } +} + +func Or(qfs ...QueryFunc) QueryFunc { + return func() string { + qs := make([]string, len(qfs)) + for i, qf := range qfs { + qs[i] = qf() + } + return strings.Join(qs, " OR ") + } +} + +func Not(qf QueryFunc) QueryFunc { + return func() string { + return "NOT " + qf() + } +} + +func FormatObjQuery(qf QueryFunc, extraFields string) string { + return fmt.Sprintf(ObjQuery, qf(), extraFields) +} + +func FormatObjsQuery(typeName string, qf QueryFunc, paginationAndSorting string, extraFields string) string { + return fmt.Sprintf(ObjsQuery, typeName, paginationAndSorting, qf(), extraFields) +} diff --git a/api/utils/reflect.go b/api/utils/reflect.go new file mode 100644 index 0000000..c38d0f1 --- /dev/null +++ b/api/utils/reflect.go @@ -0,0 +1,194 @@ +package utils + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +type DbTag struct { + Constraint string +} + +func GetFieldTags(t reflect.Type) (fieldToJsonTags map[string]string, + jsonToDbTags map[string]*DbTag, jsonToReverseEdgeTags map[string]string, err error) { + + fieldToJsonTags = make(map[string]string) + jsonToDbTags = make(map[string]*DbTag) + jsonToReverseEdgeTags = make(map[string]string) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag == "" { + return nil, nil, nil, fmt.Errorf("field %s has no json tag", field.Name) + } + jsonName := strings.Split(jsonTag, ",")[0] + fieldToJsonTags[field.Name] = jsonName + + reverseEdgeTag := field.Tag.Get("readFrom") + if reverseEdgeTag != "" { + typeAndField := strings.Split(reverseEdgeTag, ",") + if len(typeAndField) != 2 { + return nil, nil, nil, fmt.Errorf(`field %s has invalid readFrom tag, + expected format is type=,field=`, field.Name) + } + t := strings.Split(typeAndField[0], "=")[1] + f := strings.Split(typeAndField[1], "=")[1] + jsonToReverseEdgeTags[jsonName] = GetPredicateName(t, f) + } + + dbConstraintsTag := field.Tag.Get("db") + if dbConstraintsTag != "" { + jsonToDbTags[jsonName] = &DbTag{} + dbTagsSplit := strings.Split(dbConstraintsTag, ",") + for _, dbTag := range dbTagsSplit { + split := strings.Split(dbTag, "=") + if split[0] == "constraint" { + jsonToDbTags[jsonName].Constraint = split[1] + } + } + } + } + return fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, nil +} + +func GetJsonTagToValues(object any, fieldToJsonTags map[string]string) map[string]any { + values := make(map[string]any) + v := reflect.ValueOf(object) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + for fieldName, jsonName := range fieldToJsonTags { + fieldValue := v.FieldByName(fieldName) + values[jsonName] = fieldValue.Interface() + + } + return values +} + +func CreateDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, depth int) reflect.Type { + fields := make([]reflect.StructField, 0, len(fieldToJsonTags)) + for fieldName, jsonName := range fieldToJsonTags { + field, _ := t.FieldByName(fieldName) + if fieldName != "Gid" { + if field.Type.Kind() == reflect.Struct { + if depth <= 1 { + nestedFieldToJsonTags, _, _, _ := GetFieldTags(field.Type) + nestedType := CreateDynamicStruct(field.Type, nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: nestedType, + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } + } else if field.Type.Kind() == reflect.Ptr && + field.Type.Elem().Kind() == reflect.Struct { + nestedFieldToJsonTags, _, _, _ := GetFieldTags(field.Type.Elem()) + nestedType := CreateDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: reflect.PointerTo(nestedType), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } else if field.Type.Kind() == reflect.Slice && + field.Type.Elem().Kind() == reflect.Struct { + nestedFieldToJsonTags, _, _, _ := GetFieldTags(field.Type.Elem()) + nestedType := CreateDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: reflect.SliceOf(nestedType), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } else { + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: field.Type, + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) + } + + } + } + fields = append(fields, reflect.StructField{ + Name: "Gid", + Type: reflect.TypeOf(""), + Tag: reflect.StructTag(`json:"gid"`), + }, reflect.StructField{ + Name: "DgraphType", + Type: reflect.TypeOf([]string{}), + Tag: reflect.StructTag(`json:"dgraph.type"`), + }) + return reflect.StructOf(fields) +} + +func MapDynamicToFinal(dynamic any, final any, isNested bool) (uint64, error) { + vFinal := reflect.ValueOf(final).Elem() + vDynamic := reflect.ValueOf(dynamic).Elem() + + gid := uint64(0) + + for i := 0; i < vDynamic.NumField(); i++ { + + dynamicField := vDynamic.Type().Field(i) + dynamicFieldType := dynamicField.Type + dynamicValue := vDynamic.Field(i) + + var finalField reflect.Value + if dynamicField.Name == "Gid" { + finalField = vFinal.FieldByName("Gid") + gidStr := dynamicValue.String() + gid, _ = strconv.ParseUint(gidStr, 0, 64) + } else if dynamicField.Name == "DgraphType" { + fieldArrInterface := dynamicValue.Interface() + fieldArr, ok := fieldArrInterface.([]string) + if ok { + if len(fieldArr) == 0 { + if !isNested { + return 0, ErrNoObjFound + } else { + continue + } + } + } else { + return 0, fmt.Errorf("DgraphType field should be an array of strings") + } + } else { + finalField = vFinal.FieldByName(dynamicField.Name) + } + if dynamicFieldType.Kind() == reflect.Struct { + _, err := MapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface(), true) + if err != nil { + return 0, err + } + } else if dynamicFieldType.Kind() == reflect.Ptr && + dynamicFieldType.Elem().Kind() == reflect.Struct { + // if field is a pointer, find if the underlying is a struct + _, err := MapDynamicToFinal(dynamicValue.Interface(), finalField.Interface(), true) + if err != nil { + return 0, err + } + } else if dynamicFieldType.Kind() == reflect.Slice && + dynamicFieldType.Elem().Kind() == reflect.Struct { + for j := 0; j < dynamicValue.Len(); j++ { + sliceElem := dynamicValue.Index(j).Addr().Interface() + finalSliceElem := reflect.New(finalField.Type().Elem()).Elem() + _, err := MapDynamicToFinal(sliceElem, finalSliceElem.Addr().Interface(), true) + if err != nil { + return 0, err + } + finalField.Set(reflect.Append(finalField, finalSliceElem)) + } + } else { + if finalField.IsValid() && finalField.CanSet() { + // if field name is gid, convert it to uint64 + if dynamicField.Name == "Gid" { + finalField.SetUint(gid) + } else { + finalField.Set(dynamicValue) + } + } + } + } + return gid, nil +} diff --git a/utils.go b/api/utils/utils.go similarity index 62% rename from utils.go rename to api/utils/utils.go index e571bb2..ba726ed 100644 --- a/utils.go +++ b/api/utils/utils.go @@ -7,7 +7,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package modusdb +package utils import ( "fmt" @@ -15,10 +15,15 @@ import ( "github.com/dgraph-io/dgraph/v24/x" ) -func getPredicateName(typeName, fieldName string) string { +var ( + ErrNoObjFound = fmt.Errorf("no object found") + NoUniqueConstr = "unique constraint not defined for any field on type %s" +) + +func GetPredicateName(typeName, fieldName string) string { return fmt.Sprint(typeName, ".", fieldName) } -func addNamespace(ns uint64, pred string) string { +func AddNamespace(ns uint64, pred string) string { return x.NamespaceAttr(ns, pred) } diff --git a/api_dql.go b/api_dql.go deleted file mode 100644 index 6c63171..0000000 --- a/api_dql.go +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Copyright 2025 Hypermode Inc. - * Licensed under the terms of the Apache License, Version 2.0 - * See the LICENSE file that accompanied this code for further details. - * - * SPDX-FileCopyrightText: 2025 Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package modusdb - -import ( - "fmt" - "strconv" - "strings" -) - -type QueryFunc func() string - -const ( - objQuery = ` - { - obj(func: %s) { - gid: uid - expand(_all_) { - gid: uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - objsQuery = ` - { - objs(func: type("%s")%s) @filter(%s) { - gid: uid - expand(_all_) { - gid: uid - expand(_all_) - dgraph.type - } - dgraph.type - %s - } - } - ` - - reverseEdgeQuery = ` - %s: ~%s { - gid: uid - expand(_all_) - dgraph.type - } - ` - - funcUid = `uid(%d)` - funcEq = `eq(%s, %s)` - funcSimilarTo = `similar_to(%s, %d, "[%s]")` - funcAllOfTerms = `allofterms(%s, "%s")` - funcAnyOfTerms = `anyofterms(%s, "%s")` - funcAllOfText = `alloftext(%s, "%s")` - funcAnyOfText = `anyoftext(%s, "%s")` - funcRegExp = `regexp(%s, /%s/)` - funcLe = `le(%s, %s)` - funcGe = `ge(%s, %s)` - funcGt = `gt(%s, %s)` - funcLt = `lt(%s, %s)` -) - -func buildUidQuery(gid uint64) QueryFunc { - return func() string { - return fmt.Sprintf(funcUid, gid) - } -} - -func buildEqQuery(key string, value any) QueryFunc { - return func() string { - return fmt.Sprintf(funcEq, key, value) - } -} - -func buildSimilarToQuery(indexAttr string, topK int64, vec []float32) QueryFunc { - vecStrArr := make([]string, len(vec)) - for i := range vec { - vecStrArr[i] = strconv.FormatFloat(float64(vec[i]), 'f', -1, 32) - } - vecStr := strings.Join(vecStrArr, ",") - return func() string { - return fmt.Sprintf(funcSimilarTo, indexAttr, topK, vecStr) - } -} - -func buildAllOfTermsQuery(attr string, terms string) QueryFunc { - return func() string { - return fmt.Sprintf(funcAllOfTerms, attr, terms) - } -} - -func buildAnyOfTermsQuery(attr string, terms string) QueryFunc { - return func() string { - return fmt.Sprintf(funcAnyOfTerms, attr, terms) - } -} - -func buildAllOfTextQuery(attr, text string) QueryFunc { - return func() string { - return fmt.Sprintf(funcAllOfText, attr, text) - } -} - -func buildAnyOfTextQuery(attr, text string) QueryFunc { - return func() string { - return fmt.Sprintf(funcAnyOfText, attr, text) - } -} - -func buildRegExpQuery(attr, pattern string) QueryFunc { - return func() string { - return fmt.Sprintf(funcRegExp, attr, pattern) - } -} - -func buildLeQuery(attr, value string) QueryFunc { - return func() string { - return fmt.Sprintf(funcLe, attr, value) - } -} - -func buildGeQuery(attr, value string) QueryFunc { - return func() string { - return fmt.Sprintf(funcGe, attr, value) - } -} - -func buildGtQuery(attr, value string) QueryFunc { - return func() string { - return fmt.Sprintf(funcGt, attr, value) - } -} - -func buildLtQuery(attr, value string) QueryFunc { - return func() string { - return fmt.Sprintf(funcLt, attr, value) - } -} - -func And(qfs ...QueryFunc) QueryFunc { - return func() string { - qs := make([]string, len(qfs)) - for i, qf := range qfs { - qs[i] = qf() - } - return strings.Join(qs, " AND ") - } -} - -func Or(qfs ...QueryFunc) QueryFunc { - return func() string { - qs := make([]string, len(qfs)) - for i, qf := range qfs { - qs[i] = qf() - } - return strings.Join(qs, " OR ") - } -} - -func Not(qf QueryFunc) QueryFunc { - return func() string { - return "NOT " + qf() - } -} - -func formatObjQuery(qf QueryFunc, extraFields string) string { - return fmt.Sprintf(objQuery, qf(), extraFields) -} - -func formatObjsQuery(typeName string, qf QueryFunc, paginationAndSorting string, extraFields string) string { - return fmt.Sprintf(objsQuery, typeName, paginationAndSorting, qf(), extraFields) -} - -// Helper function to combine multiple filters -func filtersToQueryFunc(typeName string, filter Filter) QueryFunc { - return filterToQueryFunc(typeName, filter) -} - -func paginationToQueryString(p Pagination) string { - paginationStr := "" - if p.Limit > 0 { - paginationStr += ", " + fmt.Sprintf("first: %d", p.Limit) - } - if p.Offset > 0 { - paginationStr += ", " + fmt.Sprintf("offset: %d", p.Offset) - } else if p.After != "" { - paginationStr += ", " + fmt.Sprintf("after: %s", p.After) - } - if paginationStr == "" { - return "" - } - return paginationStr -} - -func sortingToQueryString(typeName string, s Sorting) string { - if s.OrderAscField == "" && s.OrderDescField == "" { - return "" - } - - var parts []string - first, second := s.OrderDescField, s.OrderAscField - firstOp, secondOp := "orderdesc", "orderasc" - - if !s.OrderDescFirst { - first, second = s.OrderAscField, s.OrderDescField - firstOp, secondOp = "orderasc", "orderdesc" - } - - if first != "" { - parts = append(parts, fmt.Sprintf("%s: %s", firstOp, getPredicateName(typeName, first))) - } - if second != "" { - parts = append(parts, fmt.Sprintf("%s: %s", secondOp, getPredicateName(typeName, second))) - } - - return ", " + strings.Join(parts, ", ") -} diff --git a/api_mutate_helper.go b/api_mutate_helper.go index f90c258..ffea6f8 100644 --- a/api_mutate_helper.go +++ b/api_mutate_helper.go @@ -22,6 +22,7 @@ import ( "github.com/dgraph-io/dgraph/v24/schema" "github.com/dgraph-io/dgraph/v24/worker" "github.com/dgraph-io/dgraph/v24/x" + "github.com/hypermodeinc/modusdb/api/utils" ) func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object T, @@ -31,11 +32,11 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac return fmt.Errorf("expected struct, got %s", t.Kind()) } - fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, err := utils.GetFieldTags(t) if err != nil { return err } - jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) + jsonTagToValue := utils.GetJsonTagToValues(object, fieldToJsonTags) nquads := make([]*api.NQuad, 0) uniqueConstraintFound := false @@ -53,13 +54,13 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac reverseEdge := jsonToReverseEdgeTags[jsonName] typeName := strings.Split(reverseEdge, ".")[0] u := &pb.SchemaUpdate{ - Predicate: addNamespace(n.id, reverseEdge), + Predicate: utils.AddNamespace(n.id, reverseEdge), ValueType: pb.Posting_UID, Directive: pb.SchemaUpdate_REVERSE, } sch.Preds = append(sch.Preds, u) sch.Types = append(sch.Types, &pb.TypeUpdate{ - TypeName: addNamespace(n.id, typeName), + TypeName: utils.AddNamespace(n.id, typeName), Fields: []*pb.SchemaUpdate{u}, }) continue @@ -101,11 +102,11 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac nquad = &api.NQuad{ Namespace: n.ID(), Subject: fmt.Sprint(gid), - Predicate: getPredicateName(t.Name(), jsonName), + Predicate: utils.GetPredicateName(t.Name(), jsonName), } u := &pb.SchemaUpdate{ - Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), + Predicate: utils.AddNamespace(n.id, utils.GetPredicateName(t.Name(), jsonName)), ValueType: valType, } @@ -117,7 +118,7 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac } if jsonToDbTags[jsonName] != nil { - constraint := jsonToDbTags[jsonName].constraint + constraint := jsonToDbTags[jsonName].Constraint if constraint == "vector" && valType != pb.Posting_VFLOAT { return fmt.Errorf("vector index can only be applied to []float values") } @@ -128,10 +129,10 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac nquads = append(nquads, nquad) } if !uniqueConstraintFound { - return fmt.Errorf(NoUniqueConstr, t.Name()) + return fmt.Errorf(utils.NoUniqueConstr, t.Name()) } sch.Types = append(sch.Types, &pb.TypeUpdate{ - TypeName: addNamespace(n.id, t.Name()), + TypeName: utils.AddNamespace(n.id, t.Name()), Fields: sch.Preds, }) @@ -209,7 +210,7 @@ func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { } func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { - gid, cf, err := getUniqueConstraint[T](object) + gid, cf, err := GetUniqueConstraint[T](object) if err != nil { return 0, err } @@ -227,7 +228,7 @@ func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) } if gid != 0 { gid, _, err = getByGidWithObject[T](ctx, n, gid, object) - if err != nil && err != ErrNoObjFound { + if err != nil && err != utils.ErrNoObjFound { return 0, err } if err == nil { @@ -235,7 +236,7 @@ func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) } } else if cf != nil { gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) - if err != nil && err != ErrNoObjFound { + if err != nil && err != utils.ErrNoObjFound { return 0, err } if err == nil { diff --git a/api_query_helper.go b/api_query_helper.go index 9ec211c..3b72d55 100644 --- a/api_query_helper.go +++ b/api_query_helper.go @@ -14,6 +14,9 @@ import ( "encoding/json" "fmt" "reflect" + + "github.com/hypermodeinc/modusdb/api/dql_query" + "github.com/hypermodeinc/modusdb/api/utils" ) func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, T, error) { @@ -47,14 +50,14 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac obj T, withReverse bool, args ...R) (uint64, T, error) { t := reflect.TypeOf(obj) - fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTag, jsonToReverseEdgeTags, err := utils.GetFieldTags(t) if err != nil { return 0, obj, err } readFromQuery := "" if withReverse { for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(reverseEdgeQuery, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + readFromQuery += fmt.Sprintf(dql_query.ReverseEdgeQuery, utils.GetPredicateName(t.Name(), jsonTag), reverseEdgeTag) } } @@ -62,14 +65,14 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac var query string gid, ok := any(args[0]).(uint64) if ok { - query = formatObjQuery(buildUidQuery(gid), readFromQuery) + query = dql_query.FormatObjQuery(dql_query.BuildUidQuery(gid), readFromQuery) } else if cf, ok = any(args[0]).(ConstrainedField); ok { - query = formatObjQuery(buildEqQuery(getPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) + query = dql_query.FormatObjQuery(dql_query.BuildEqQuery(utils.GetPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) } else { return 0, obj, fmt.Errorf("invalid unique field type") } - if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { + if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].Constraint == "" { return 0, obj, fmt.Errorf("constraint not defined for field %s", cf.Key) } @@ -78,7 +81,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac return 0, obj, err } - dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) + dynamicType := utils.CreateDynamicStruct(t, fieldToJsonTags, 1) dynamicInstance := reflect.New(dynamicType).Interface() @@ -95,12 +98,12 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac // Check if we have at least one object in the response if len(result.Obj) == 0 { - return 0, obj, ErrNoObjFound + return 0, obj, utils.ErrNoObjFound } // Map the dynamic struct to the final type T finalObject := reflect.New(t).Interface() - gid, err = mapDynamicToFinal(result.Obj[0], finalObject, false) + gid, err = utils.MapDynamicToFinal(result.Obj[0], finalObject, false) if err != nil { return 0, obj, err } @@ -120,12 +123,12 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar withReverse bool) ([]uint64, []T, error) { var obj T t := reflect.TypeOf(obj) - fieldToJsonTags, _, jsonToReverseEdgeTags, err := getFieldTags(t) + fieldToJsonTags, _, jsonToReverseEdgeTags, err := utils.GetFieldTags(t) if err != nil { return nil, nil, err } - var filterQueryFunc QueryFunc = func() string { + var filterQueryFunc dql_query.QueryFunc = func() string { return "" } var paginationAndSorting string @@ -146,18 +149,18 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar readFromQuery := "" if withReverse { for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(reverseEdgeQuery, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + readFromQuery += fmt.Sprintf(dql_query.ReverseEdgeQuery, utils.GetPredicateName(t.Name(), jsonTag), reverseEdgeTag) } } - query := formatObjsQuery(t.Name(), filterQueryFunc, paginationAndSorting, readFromQuery) + query := dql_query.FormatObjsQuery(t.Name(), filterQueryFunc, paginationAndSorting, readFromQuery) resp, err := n.queryWithLock(ctx, query) if err != nil { return nil, nil, err } - dynamicType := createDynamicStruct(t, fieldToJsonTags, 1) + dynamicType := utils.CreateDynamicStruct(t, fieldToJsonTags, 1) var result struct { Objs []any `json:"objs"` @@ -185,7 +188,7 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar var objs []T for _, obj := range result.Objs { finalObject := reflect.New(t).Interface() - gid, err := mapDynamicToFinal(obj, finalObject, false) + gid, err := utils.MapDynamicToFinal(obj, finalObject, false) if err != nil { return nil, nil, err } diff --git a/api_reflect.go b/api_reflect.go index 5a32b01..029a223 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -12,203 +12,17 @@ package modusdb import ( "fmt" "reflect" - "strconv" - "strings" -) - -type dbTag struct { - constraint string -} - -func getFieldTags(t reflect.Type) (fieldToJsonTags map[string]string, - jsonToDbTags map[string]*dbTag, jsonToReverseEdgeTags map[string]string, err error) { - - fieldToJsonTags = make(map[string]string) - jsonToDbTags = make(map[string]*dbTag) - jsonToReverseEdgeTags = make(map[string]string) - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - jsonTag := field.Tag.Get("json") - if jsonTag == "" { - return nil, nil, nil, fmt.Errorf("field %s has no json tag", field.Name) - } - jsonName := strings.Split(jsonTag, ",")[0] - fieldToJsonTags[field.Name] = jsonName - - reverseEdgeTag := field.Tag.Get("readFrom") - if reverseEdgeTag != "" { - typeAndField := strings.Split(reverseEdgeTag, ",") - if len(typeAndField) != 2 { - return nil, nil, nil, fmt.Errorf(`field %s has invalid readFrom tag, - expected format is type=,field=`, field.Name) - } - t := strings.Split(typeAndField[0], "=")[1] - f := strings.Split(typeAndField[1], "=")[1] - jsonToReverseEdgeTags[jsonName] = getPredicateName(t, f) - } - - dbConstraintsTag := field.Tag.Get("db") - if dbConstraintsTag != "" { - jsonToDbTags[jsonName] = &dbTag{} - dbTagsSplit := strings.Split(dbConstraintsTag, ",") - for _, dbTag := range dbTagsSplit { - split := strings.Split(dbTag, "=") - if split[0] == "constraint" { - jsonToDbTags[jsonName].constraint = split[1] - } - } - } - } - return fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, nil -} - -func getJsonTagToValues(object any, fieldToJsonTags map[string]string) map[string]any { - values := make(map[string]any) - v := reflect.ValueOf(object) - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - for fieldName, jsonName := range fieldToJsonTags { - fieldValue := v.FieldByName(fieldName) - values[jsonName] = fieldValue.Interface() - } - return values -} - -func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, depth int) reflect.Type { - fields := make([]reflect.StructField, 0, len(fieldToJsonTags)) - for fieldName, jsonName := range fieldToJsonTags { - field, _ := t.FieldByName(fieldName) - if fieldName != "Gid" { - if field.Type.Kind() == reflect.Struct { - if depth <= 1 { - nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type) - nestedType := createDynamicStruct(field.Type, nestedFieldToJsonTags, depth+1) - fields = append(fields, reflect.StructField{ - Name: field.Name, - Type: nestedType, - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), - }) - } - } else if field.Type.Kind() == reflect.Ptr && - field.Type.Elem().Kind() == reflect.Struct { - nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type.Elem()) - nestedType := createDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) - fields = append(fields, reflect.StructField{ - Name: field.Name, - Type: reflect.PointerTo(nestedType), - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), - }) - } else if field.Type.Kind() == reflect.Slice && - field.Type.Elem().Kind() == reflect.Struct { - nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type.Elem()) - nestedType := createDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) - fields = append(fields, reflect.StructField{ - Name: field.Name, - Type: reflect.SliceOf(nestedType), - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), - }) - } else { - fields = append(fields, reflect.StructField{ - Name: field.Name, - Type: field.Type, - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), - }) - } - - } - } - fields = append(fields, reflect.StructField{ - Name: "Gid", - Type: reflect.TypeOf(""), - Tag: reflect.StructTag(`json:"gid"`), - }, reflect.StructField{ - Name: "DgraphType", - Type: reflect.TypeOf([]string{}), - Tag: reflect.StructTag(`json:"dgraph.type"`), - }) - return reflect.StructOf(fields) -} - -func mapDynamicToFinal(dynamic any, final any, isNested bool) (uint64, error) { - vFinal := reflect.ValueOf(final).Elem() - vDynamic := reflect.ValueOf(dynamic).Elem() - - gid := uint64(0) - - for i := 0; i < vDynamic.NumField(); i++ { - - dynamicField := vDynamic.Type().Field(i) - dynamicFieldType := dynamicField.Type - dynamicValue := vDynamic.Field(i) - - var finalField reflect.Value - if dynamicField.Name == "Gid" { - finalField = vFinal.FieldByName("Gid") - gidStr := dynamicValue.String() - gid, _ = strconv.ParseUint(gidStr, 0, 64) - } else if dynamicField.Name == "DgraphType" { - fieldArrInterface := dynamicValue.Interface() - fieldArr, ok := fieldArrInterface.([]string) - if ok { - if len(fieldArr) == 0 { - if !isNested { - return 0, ErrNoObjFound - } else { - continue - } - } - } else { - return 0, fmt.Errorf("DgraphType field should be an array of strings") - } - } else { - finalField = vFinal.FieldByName(dynamicField.Name) - } - if dynamicFieldType.Kind() == reflect.Struct { - _, err := mapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface(), true) - if err != nil { - return 0, err - } - } else if dynamicFieldType.Kind() == reflect.Ptr && - dynamicFieldType.Elem().Kind() == reflect.Struct { - // if field is a pointer, find if the underlying is a struct - _, err := mapDynamicToFinal(dynamicValue.Interface(), finalField.Interface(), true) - if err != nil { - return 0, err - } - } else if dynamicFieldType.Kind() == reflect.Slice && - dynamicFieldType.Elem().Kind() == reflect.Struct { - for j := 0; j < dynamicValue.Len(); j++ { - sliceElem := dynamicValue.Index(j).Addr().Interface() - finalSliceElem := reflect.New(finalField.Type().Elem()).Elem() - _, err := mapDynamicToFinal(sliceElem, finalSliceElem.Addr().Interface(), true) - if err != nil { - return 0, err - } - finalField.Set(reflect.Append(finalField, finalSliceElem)) - } - } else { - if finalField.IsValid() && finalField.CanSet() { - // if field name is gid, convert it to uint64 - if dynamicField.Name == "Gid" { - finalField.SetUint(gid) - } else { - finalField.Set(dynamicValue) - } - } - } - } - return gid, nil -} + "github.com/hypermodeinc/modusdb/api/utils" +) -func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { +func GetUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { t := reflect.TypeOf(object) - fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTags, _, err := utils.GetFieldTags(t) if err != nil { return 0, nil, err } - jsonTagToValue := getJsonTagToValues(object, fieldToJsonTags) + jsonTagToValue := utils.GetJsonTagToValues(object, fieldToJsonTags) for jsonName, value := range jsonTagToValue { if jsonName == "gid" { @@ -220,7 +34,7 @@ func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { return gid, nil, nil } } - if jsonToDbTags[jsonName] != nil && isValidUniqueIndex(jsonToDbTags[jsonName].constraint) { + if jsonToDbTags[jsonName] != nil && IsValidUniqueIndex(jsonToDbTags[jsonName].Constraint) { // check if value is zero or nil if value == reflect.Zero(reflect.TypeOf(value)).Interface() || value == nil { continue @@ -232,9 +46,9 @@ func getUniqueConstraint[T any](object T) (uint64, *ConstrainedField, error) { } } - return 0, nil, fmt.Errorf(NoUniqueConstr, t.Name()) + return 0, nil, fmt.Errorf(utils.NoUniqueConstr, t.Name()) } -func isValidUniqueIndex(name string) bool { +func IsValidUniqueIndex(name string) bool { return name == "unique" } diff --git a/api_test.go b/api_test.go index f7f15e0..cb0ca7a 100644 --- a/api_test.go +++ b/api_test.go @@ -17,6 +17,7 @@ import ( "github.com/stretchr/testify/require" "github.com/hypermodeinc/modusdb" + "github.com/hypermodeinc/modusdb/api/utils" ) type User struct { @@ -767,7 +768,7 @@ func TestNestedObjectMutationWithBadType(t *testing.T) { _, _, err = modusdb.Create(db, branch, db1.ID()) require.Error(t, err) - require.Equal(t, fmt.Sprintf(modusdb.NoUniqueConstr, "BadProject"), err.Error()) + require.Equal(t, fmt.Sprintf(utils.NoUniqueConstr, "BadProject"), err.Error()) proj := BadProject{ Name: "P", @@ -776,7 +777,7 @@ func TestNestedObjectMutationWithBadType(t *testing.T) { _, _, err = modusdb.Create(db, proj, db1.ID()) require.Error(t, err) - require.Equal(t, fmt.Sprintf(modusdb.NoUniqueConstr, "BadProject"), err.Error()) + require.Equal(t, fmt.Sprintf(utils.NoUniqueConstr, "BadProject"), err.Error()) } diff --git a/api_types.go b/api_types.go index 2384f35..0187b38 100644 --- a/api_types.go +++ b/api_types.go @@ -20,15 +20,12 @@ import ( "github.com/dgraph-io/dgraph/v24/protos/pb" "github.com/dgraph-io/dgraph/v24/types" "github.com/dgraph-io/dgraph/v24/x" + "github.com/hypermodeinc/modusdb/api/dql_query" + "github.com/hypermodeinc/modusdb/api/utils" "github.com/twpayne/go-geom" "github.com/twpayne/go-geom/encoding/wkb" ) -var ( - ErrNoObjFound = fmt.Errorf("no object found") - NoUniqueConstr = "unique constraint not defined for any field on type %s" -) - type UniqueField interface { uint64 | ConstrainedField } @@ -197,53 +194,98 @@ func valueToApiVal(v any) (*api.Value, error) { } } -func filterToQueryFunc(typeName string, f Filter) QueryFunc { +func filterToQueryFunc(typeName string, f Filter) dql_query.QueryFunc { // Handle logical operators first if f.And != nil { - return And(filterToQueryFunc(typeName, *f.And)) + return dql_query.And(filterToQueryFunc(typeName, *f.And)) } if f.Or != nil { - return Or(filterToQueryFunc(typeName, *f.Or)) + return dql_query.Or(filterToQueryFunc(typeName, *f.Or)) } if f.Not != nil { - return Not(filterToQueryFunc(typeName, *f.Not)) + return dql_query.Not(filterToQueryFunc(typeName, *f.Not)) } // Handle field predicates if f.String.Equals != "" { - return buildEqQuery(getPredicateName(typeName, f.Field), f.String.Equals) + return dql_query.BuildEqQuery(utils.GetPredicateName(typeName, f.Field), f.String.Equals) } if len(f.String.AllOfTerms) != 0 { - return buildAllOfTermsQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AllOfTerms, " ")) + return dql_query.BuildAllOfTermsQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AllOfTerms, " ")) } if len(f.String.AnyOfTerms) != 0 { - return buildAnyOfTermsQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfTerms, " ")) + return dql_query.BuildAnyOfTermsQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfTerms, " ")) } if len(f.String.AllOfText) != 0 { - return buildAllOfTextQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AllOfText, " ")) + return dql_query.BuildAllOfTextQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AllOfText, " ")) } if len(f.String.AnyOfText) != 0 { - return buildAnyOfTextQuery(getPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfText, " ")) + return dql_query.BuildAnyOfTextQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfText, " ")) } if f.String.RegExp != "" { - return buildRegExpQuery(getPredicateName(typeName, f.Field), f.String.RegExp) + return dql_query.BuildRegExpQuery(utils.GetPredicateName(typeName, f.Field), f.String.RegExp) } if f.String.LessThan != "" { - return buildLtQuery(getPredicateName(typeName, f.Field), f.String.LessThan) + return dql_query.BuildLtQuery(utils.GetPredicateName(typeName, f.Field), f.String.LessThan) } if f.String.LessOrEqual != "" { - return buildLeQuery(getPredicateName(typeName, f.Field), f.String.LessOrEqual) + return dql_query.BuildLeQuery(utils.GetPredicateName(typeName, f.Field), f.String.LessOrEqual) } if f.String.GreaterThan != "" { - return buildGtQuery(getPredicateName(typeName, f.Field), f.String.GreaterThan) + return dql_query.BuildGtQuery(utils.GetPredicateName(typeName, f.Field), f.String.GreaterThan) } if f.String.GreaterOrEqual != "" { - return buildGeQuery(getPredicateName(typeName, f.Field), f.String.GreaterOrEqual) + return dql_query.BuildGeQuery(utils.GetPredicateName(typeName, f.Field), f.String.GreaterOrEqual) } if f.Vector.SimilarTo != nil { - return buildSimilarToQuery(getPredicateName(typeName, f.Field), f.Vector.TopK, f.Vector.SimilarTo) + return dql_query.BuildSimilarToQuery(utils.GetPredicateName(typeName, f.Field), f.Vector.TopK, f.Vector.SimilarTo) } // Return empty query if no conditions match return func() string { return "" } } + +// Helper function to combine multiple filters +func filtersToQueryFunc(typeName string, filter Filter) dql_query.QueryFunc { + return filterToQueryFunc(typeName, filter) +} + +func paginationToQueryString(p Pagination) string { + paginationStr := "" + if p.Limit > 0 { + paginationStr += ", " + fmt.Sprintf("first: %d", p.Limit) + } + if p.Offset > 0 { + paginationStr += ", " + fmt.Sprintf("offset: %d", p.Offset) + } else if p.After != "" { + paginationStr += ", " + fmt.Sprintf("after: %s", p.After) + } + if paginationStr == "" { + return "" + } + return paginationStr +} + +func sortingToQueryString(typeName string, s Sorting) string { + if s.OrderAscField == "" && s.OrderDescField == "" { + return "" + } + + var parts []string + first, second := s.OrderDescField, s.OrderAscField + firstOp, secondOp := "orderdesc", "orderasc" + + if !s.OrderDescFirst { + first, second = s.OrderAscField, s.OrderDescField + firstOp, secondOp = "orderasc", "orderdesc" + } + + if first != "" { + parts = append(parts, fmt.Sprintf("%s: %s", firstOp, utils.GetPredicateName(typeName, first))) + } + if second != "" { + parts = append(parts, fmt.Sprintf("%s: %s", secondOp, utils.GetPredicateName(typeName, second))) + } + + return ", " + strings.Join(parts, ", ") +} From 3a824dc6ca70d6e651be58459c9349593c9582f2 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 08:26:47 -0800 Subject: [PATCH 10/17] . --- api_query_helper.go | 3 ++- api_types.go | 30 ++++++++++++++++++++---------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/api_query_helper.go b/api_query_helper.go index 3b72d55..d49979c 100644 --- a/api_query_helper.go +++ b/api_query_helper.go @@ -67,7 +67,8 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac if ok { query = dql_query.FormatObjQuery(dql_query.BuildUidQuery(gid), readFromQuery) } else if cf, ok = any(args[0]).(ConstrainedField); ok { - query = dql_query.FormatObjQuery(dql_query.BuildEqQuery(utils.GetPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) + query = dql_query.FormatObjQuery(dql_query.BuildEqQuery(utils.GetPredicateName(t.Name(), + cf.Key), cf.Value), readFromQuery) } else { return 0, obj, fmt.Errorf("invalid unique field type") } diff --git a/api_types.go b/api_types.go index 0187b38..ebcb4b2 100644 --- a/api_types.go +++ b/api_types.go @@ -211,34 +211,44 @@ func filterToQueryFunc(typeName string, f Filter) dql_query.QueryFunc { return dql_query.BuildEqQuery(utils.GetPredicateName(typeName, f.Field), f.String.Equals) } if len(f.String.AllOfTerms) != 0 { - return dql_query.BuildAllOfTermsQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AllOfTerms, " ")) + return dql_query.BuildAllOfTermsQuery(utils.GetPredicateName(typeName, + f.Field), strings.Join(f.String.AllOfTerms, " ")) } if len(f.String.AnyOfTerms) != 0 { - return dql_query.BuildAnyOfTermsQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfTerms, " ")) + return dql_query.BuildAnyOfTermsQuery(utils.GetPredicateName(typeName, + f.Field), strings.Join(f.String.AnyOfTerms, " ")) } if len(f.String.AllOfText) != 0 { - return dql_query.BuildAllOfTextQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AllOfText, " ")) + return dql_query.BuildAllOfTextQuery(utils.GetPredicateName(typeName, + f.Field), strings.Join(f.String.AllOfText, " ")) } if len(f.String.AnyOfText) != 0 { - return dql_query.BuildAnyOfTextQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfText, " ")) + return dql_query.BuildAnyOfTextQuery(utils.GetPredicateName(typeName, + f.Field), strings.Join(f.String.AnyOfText, " ")) } if f.String.RegExp != "" { - return dql_query.BuildRegExpQuery(utils.GetPredicateName(typeName, f.Field), f.String.RegExp) + return dql_query.BuildRegExpQuery(utils.GetPredicateName(typeName, + f.Field), f.String.RegExp) } if f.String.LessThan != "" { - return dql_query.BuildLtQuery(utils.GetPredicateName(typeName, f.Field), f.String.LessThan) + return dql_query.BuildLtQuery(utils.GetPredicateName(typeName, + f.Field), f.String.LessThan) } if f.String.LessOrEqual != "" { - return dql_query.BuildLeQuery(utils.GetPredicateName(typeName, f.Field), f.String.LessOrEqual) + return dql_query.BuildLeQuery(utils.GetPredicateName(typeName, + f.Field), f.String.LessOrEqual) } if f.String.GreaterThan != "" { - return dql_query.BuildGtQuery(utils.GetPredicateName(typeName, f.Field), f.String.GreaterThan) + return dql_query.BuildGtQuery(utils.GetPredicateName(typeName, + f.Field), f.String.GreaterThan) } if f.String.GreaterOrEqual != "" { - return dql_query.BuildGeQuery(utils.GetPredicateName(typeName, f.Field), f.String.GreaterOrEqual) + return dql_query.BuildGeQuery(utils.GetPredicateName(typeName, + f.Field), f.String.GreaterOrEqual) } if f.Vector.SimilarTo != nil { - return dql_query.BuildSimilarToQuery(utils.GetPredicateName(typeName, f.Field), f.Vector.TopK, f.Vector.SimilarTo) + return dql_query.BuildSimilarToQuery(utils.GetPredicateName(typeName, + f.Field), f.Vector.TopK, f.Vector.SimilarTo) } // Return empty query if no conditions match From 94979bc9fc7a0d87636a0ff3600309bfe462bee5 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 09:58:25 -0800 Subject: [PATCH 11/17] continued refactoring --- api/mutations/mutations.go | 69 +++++ api/{dql_query => query_gen}/dql_query.go | 2 +- api/utils/dgraph.go | 146 +++++++++++ api/utils/reflect.go | 16 ++ api_mutate_helper.go | 300 ---------------------- api_query_helper.go | 47 +--- api_types.go | 125 ++------- api.go => apis.go | 6 +- api_test.go => apis_test.go | 0 mutation_create.go | 121 +++++++++ mutation_helpers.go | 131 ++++++++++ 11 files changed, 516 insertions(+), 447 deletions(-) create mode 100644 api/mutations/mutations.go rename api/{dql_query => query_gen}/dql_query.go (99%) create mode 100644 api/utils/dgraph.go delete mode 100644 api_mutate_helper.go rename api.go => apis.go (95%) rename api_test.go => apis_test.go (100%) create mode 100644 mutation_create.go create mode 100644 mutation_helpers.go diff --git a/api/mutations/mutations.go b/api/mutations/mutations.go new file mode 100644 index 0000000..1217499 --- /dev/null +++ b/api/mutations/mutations.go @@ -0,0 +1,69 @@ +package mutations + +import ( + "fmt" + "reflect" + "strings" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/schema" + "github.com/hypermodeinc/modusdb/api/utils" +) + +func HandleReverseEdge(jsonName string, value reflect.Type, nsId uint64, sch *schema.ParsedSchema, jsonToReverseEdgeTags map[string]string) error { + if jsonToReverseEdgeTags[jsonName] == "" { + return nil + } + + if value.Kind() != reflect.Slice || value.Elem().Kind() != reflect.Struct { + return fmt.Errorf("reverse edge %s should be a slice of structs", jsonName) + } + + reverseEdge := jsonToReverseEdgeTags[jsonName] + typeName := strings.Split(reverseEdge, ".")[0] + u := &pb.SchemaUpdate{ + Predicate: utils.AddNamespace(nsId, reverseEdge), + ValueType: pb.Posting_UID, + Directive: pb.SchemaUpdate_REVERSE, + } + + sch.Preds = append(sch.Preds, u) + sch.Types = append(sch.Types, &pb.TypeUpdate{ + TypeName: utils.AddNamespace(nsId, typeName), + Fields: []*pb.SchemaUpdate{u}, + }) + return nil +} + +func CreateNQuadAndSchema(value any, gid uint64, jsonName string, t reflect.Type, nsId uint64) (*api.NQuad, *pb.SchemaUpdate, error) { + valType, err := utils.ValueToPosting_ValType(value) + if err != nil { + return nil, nil, err + } + + val, err := utils.ValueToApiVal(value) + if err != nil { + return nil, nil, err + } + + nquad := &api.NQuad{ + Namespace: nsId, + Subject: fmt.Sprint(gid), + Predicate: utils.GetPredicateName(t.Name(), jsonName), + } + + u := &pb.SchemaUpdate{ + Predicate: utils.AddNamespace(nsId, utils.GetPredicateName(t.Name(), jsonName)), + ValueType: valType, + } + + if valType == pb.Posting_UID { + nquad.ObjectId = fmt.Sprint(value) + u.Directive = pb.SchemaUpdate_REVERSE + } else { + nquad.ObjectValue = val + } + + return nquad, u, nil +} diff --git a/api/dql_query/dql_query.go b/api/query_gen/dql_query.go similarity index 99% rename from api/dql_query/dql_query.go rename to api/query_gen/dql_query.go index 54f09d0..5922342 100644 --- a/api/dql_query/dql_query.go +++ b/api/query_gen/dql_query.go @@ -1,4 +1,4 @@ -package dql_query +package query_gen import ( "fmt" diff --git a/api/utils/dgraph.go b/api/utils/dgraph.go new file mode 100644 index 0000000..c89cf0c --- /dev/null +++ b/api/utils/dgraph.go @@ -0,0 +1,146 @@ +package utils + +import ( + "encoding/binary" + "fmt" + "time" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/types" + "github.com/twpayne/go-geom" + "github.com/twpayne/go-geom/encoding/wkb" +) + +func addIndex(u *pb.SchemaUpdate, index string, uniqueConstraintExists bool) bool { + u.Directive = pb.SchemaUpdate_INDEX + switch index { + case "exact": + u.Tokenizer = []string{"exact"} + case "term": + u.Tokenizer = []string{"term"} + case "hash": + u.Tokenizer = []string{"hash"} + case "unique": + u.Tokenizer = []string{"exact"} + u.Unique = true + u.Upsert = true + uniqueConstraintExists = true + case "fulltext": + u.Tokenizer = []string{"fulltext"} + case "trigram": + u.Tokenizer = []string{"trigram"} + case "vector": + u.IndexSpecs = []*pb.VectorIndexSpec{ + { + Name: "hnsw", + Options: []*pb.OptionPair{ + { + Key: "metric", + Value: "cosine", + }, + }, + }, + } + default: + return uniqueConstraintExists + } + return uniqueConstraintExists +} + +func ValueToPosting_ValType(v any) (pb.Posting_ValType, error) { + switch v.(type) { + case string: + return pb.Posting_STRING, nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32: + return pb.Posting_INT, nil + case uint64: + return pb.Posting_UID, nil + case bool: + return pb.Posting_BOOL, nil + case float32, float64: + return pb.Posting_FLOAT, nil + case []byte: + return pb.Posting_BINARY, nil + case time.Time: + return pb.Posting_DATETIME, nil + case geom.Point: + return pb.Posting_GEO, nil + case []float32, []float64: + return pb.Posting_VFLOAT, nil + default: + return pb.Posting_DEFAULT, fmt.Errorf("unsupported type %T", v) + } +} + +func ValueToApiVal(v any) (*api.Value, error) { + switch val := v.(type) { + case string: + return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil + case int: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int8: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int16: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int32: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case int64: + return &api.Value{Val: &api.Value_IntVal{IntVal: val}}, nil + case uint8: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint16: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint32: + return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil + case uint64: + return &api.Value{Val: &api.Value_UidVal{UidVal: val}}, nil + case bool: + return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}}, nil + case float32: + return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil + case float64: + return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil + case []float32: + return &api.Value{Val: &api.Value_Vfloat32Val{ + Vfloat32Val: types.FloatArrayAsBytes(val)}}, nil + case []float64: + float32Slice := make([]float32, len(val)) + for i, v := range val { + float32Slice[i] = float32(v) + } + return &api.Value{Val: &api.Value_Vfloat32Val{ + Vfloat32Val: types.FloatArrayAsBytes(float32Slice)}}, nil + case []byte: + return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil + case time.Time: + bytes, err := val.MarshalBinary() + if err != nil { + return nil, err + } + return &api.Value{Val: &api.Value_DateVal{DateVal: bytes}}, nil + case geom.Point: + bytes, err := wkb.Marshal(&val, binary.LittleEndian) + if err != nil { + return nil, err + } + return &api.Value{Val: &api.Value_GeoVal{GeoVal: bytes}}, nil + case uint: + return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}}, nil + default: + return nil, fmt.Errorf("unsupported type %T", v) + } +} + +func HandleConstraints(u *pb.SchemaUpdate, jsonToDbTags map[string]*DbTag, jsonName string, valType pb.Posting_ValType, uniqueConstraintFound bool) (bool, error) { + if jsonToDbTags[jsonName] == nil { + return uniqueConstraintFound, nil + } + + constraint := jsonToDbTags[jsonName].Constraint + if constraint == "vector" && valType != pb.Posting_VFLOAT { + return false, fmt.Errorf("vector index can only be applied to []float values") + } + + return addIndex(u, constraint, uniqueConstraintFound), nil +} diff --git a/api/utils/reflect.go b/api/utils/reflect.go index c38d0f1..a4a57e0 100644 --- a/api/utils/reflect.go +++ b/api/utils/reflect.go @@ -192,3 +192,19 @@ func MapDynamicToFinal(dynamic any, final any, isNested bool) (uint64, error) { } return gid, nil } + +func ConvertDynamicToTyped[T any](obj any, t reflect.Type) (uint64, T, error) { + var result T + finalObject := reflect.New(t).Interface() + gid, err := MapDynamicToFinal(obj, finalObject, false) + if err != nil { + return 0, result, err + } + + if typedPtr, ok := finalObject.(*T); ok { + return gid, *typedPtr, nil + } else if dirType, ok := finalObject.(T); ok { + return gid, dirType, nil + } + return 0, result, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) +} diff --git a/api_mutate_helper.go b/api_mutate_helper.go deleted file mode 100644 index ffea6f8..0000000 --- a/api_mutate_helper.go +++ /dev/null @@ -1,300 +0,0 @@ -/* - * Copyright 2025 Hypermode Inc. - * Licensed under the terms of the Apache License, Version 2.0 - * See the LICENSE file that accompanied this code for further details. - * - * SPDX-FileCopyrightText: 2025 Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package modusdb - -import ( - "context" - "fmt" - "reflect" - "strings" - - "github.com/dgraph-io/dgo/v240/protos/api" - "github.com/dgraph-io/dgraph/v24/dql" - "github.com/dgraph-io/dgraph/v24/protos/pb" - "github.com/dgraph-io/dgraph/v24/query" - "github.com/dgraph-io/dgraph/v24/schema" - "github.com/dgraph-io/dgraph/v24/worker" - "github.com/dgraph-io/dgraph/v24/x" - "github.com/hypermodeinc/modusdb/api/utils" -) - -func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object T, - gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { - t := reflect.TypeOf(object) - if t.Kind() != reflect.Struct { - return fmt.Errorf("expected struct, got %s", t.Kind()) - } - - fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, err := utils.GetFieldTags(t) - if err != nil { - return err - } - jsonTagToValue := utils.GetJsonTagToValues(object, fieldToJsonTags) - - nquads := make([]*api.NQuad, 0) - uniqueConstraintFound := false - for jsonName, value := range jsonTagToValue { - var val *api.Value - var valType pb.Posting_ValType - - reflectValueType := reflect.TypeOf(value) - var nquad *api.NQuad - - if jsonToReverseEdgeTags[jsonName] != "" { - if reflectValueType.Kind() != reflect.Slice || reflectValueType.Elem().Kind() != reflect.Struct { - return fmt.Errorf("reverse edge %s should be a slice of structs", jsonName) - } - reverseEdge := jsonToReverseEdgeTags[jsonName] - typeName := strings.Split(reverseEdge, ".")[0] - u := &pb.SchemaUpdate{ - Predicate: utils.AddNamespace(n.id, reverseEdge), - ValueType: pb.Posting_UID, - Directive: pb.SchemaUpdate_REVERSE, - } - sch.Preds = append(sch.Preds, u) - sch.Types = append(sch.Types, &pb.TypeUpdate{ - TypeName: utils.AddNamespace(n.id, typeName), - Fields: []*pb.SchemaUpdate{u}, - }) - continue - } - if jsonName == "gid" { - uniqueConstraintFound = true - continue - } - - if reflectValueType.Kind() == reflect.Struct { - value = reflect.ValueOf(value).Interface() - newGid, err := getUidOrMutate(ctx, n.db, n, value) - if err != nil { - return err - } - value = newGid - } else if reflectValueType.Kind() == reflect.Pointer { - // dereference the pointer - reflectValueType = reflectValueType.Elem() - if reflectValueType.Kind() == reflect.Struct { - // convert value to pointer, and then dereference - value = reflect.ValueOf(value).Elem().Interface() - newGid, err := getUidOrMutate(ctx, n.db, n, value) - if err != nil { - return err - } - value = newGid - } - } - valType, err = valueToPosting_ValType(value) - if err != nil { - return err - } - val, err = valueToApiVal(value) - if err != nil { - return err - } - - nquad = &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: utils.GetPredicateName(t.Name(), jsonName), - } - - u := &pb.SchemaUpdate{ - Predicate: utils.AddNamespace(n.id, utils.GetPredicateName(t.Name(), jsonName)), - ValueType: valType, - } - - if valType == pb.Posting_UID { - nquad.ObjectId = fmt.Sprint(value) - u.Directive = pb.SchemaUpdate_REVERSE - } else { - nquad.ObjectValue = val - } - - if jsonToDbTags[jsonName] != nil { - constraint := jsonToDbTags[jsonName].Constraint - if constraint == "vector" && valType != pb.Posting_VFLOAT { - return fmt.Errorf("vector index can only be applied to []float values") - } - uniqueConstraintFound = addIndex(u, constraint, uniqueConstraintFound) - } - - sch.Preds = append(sch.Preds, u) - nquads = append(nquads, nquad) - } - if !uniqueConstraintFound { - return fmt.Errorf(utils.NoUniqueConstr, t.Name()) - } - sch.Types = append(sch.Types, &pb.TypeUpdate{ - TypeName: utils.AddNamespace(n.id, t.Name()), - Fields: sch.Preds, - }) - - val, err := valueToApiVal(t.Name()) - if err != nil { - return err - } - typeNquad := &api.NQuad{ - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: "dgraph.type", - ObjectValue: val, - } - nquads = append(nquads, typeNquad) - - *dms = append(*dms, &dql.Mutation{ - Set: nquads, - }) - - return nil -} - -func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { - return []*dql.Mutation{{ - Del: []*api.NQuad{ - { - Namespace: n.ID(), - Subject: fmt.Sprint(gid), - Predicate: x.Star, - ObjectValue: &api.Value{ - Val: &api.Value_DefaultVal{DefaultVal: x.Star}, - }, - }, - }, - }} -} - -func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { - edges, err := query.ToDirectedEdges(dms, nil) - if err != nil { - return err - } - - if !db.isOpen { - return ErrClosedDB - } - - startTs, err := db.z.nextTs() - if err != nil { - return err - } - commitTs, err := db.z.nextTs() - if err != nil { - return err - } - - m := &pb.Mutations{ - GroupId: 1, - StartTs: startTs, - Edges: edges, - } - m.Edges, err = query.ExpandEdges(ctx, m) - if err != nil { - return fmt.Errorf("error expanding edges: %w", err) - } - - p := &pb.Proposal{Mutations: m, StartTs: startTs} - if err := worker.ApplyMutations(ctx, p); err != nil { - return err - } - - return worker.ApplyCommited(ctx, &pb.OracleDelta{ - Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, - }) -} - -func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { - gid, cf, err := GetUniqueConstraint[T](object) - if err != nil { - return 0, err - } - - dms := make([]*dql.Mutation, 0) - sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) - if err != nil { - return 0, err - } - - err = n.alterSchemaWithParsed(ctx, sch) - if err != nil { - return 0, err - } - if gid != 0 { - gid, _, err = getByGidWithObject[T](ctx, n, gid, object) - if err != nil && err != utils.ErrNoObjFound { - return 0, err - } - if err == nil { - return gid, nil - } - } else if cf != nil { - gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) - if err != nil && err != utils.ErrNoObjFound { - return 0, err - } - if err == nil { - return gid, nil - } - } - - gid, err = db.z.nextUID() - if err != nil { - return 0, err - } - - dms = make([]*dql.Mutation, 0) - err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) - if err != nil { - return 0, err - } - - err = applyDqlMutations(ctx, db, dms) - if err != nil { - return 0, err - } - - return gid, nil -} - -func addIndex(u *pb.SchemaUpdate, index string, uniqueConstraintExists bool) bool { - u.Directive = pb.SchemaUpdate_INDEX - switch index { - case "exact": - u.Tokenizer = []string{"exact"} - case "term": - u.Tokenizer = []string{"term"} - case "hash": - u.Tokenizer = []string{"hash"} - case "unique": - u.Tokenizer = []string{"exact"} - u.Unique = true - u.Upsert = true - uniqueConstraintExists = true - case "fulltext": - u.Tokenizer = []string{"fulltext"} - case "trigram": - u.Tokenizer = []string{"trigram"} - case "vector": - u.IndexSpecs = []*pb.VectorIndexSpec{ - { - Name: "hnsw", - Options: []*pb.OptionPair{ - { - Key: "metric", - Value: "cosine", - }, - }, - }, - } - default: - return uniqueConstraintExists - } - return uniqueConstraintExists -} diff --git a/api_query_helper.go b/api_query_helper.go index d49979c..1b7b726 100644 --- a/api_query_helper.go +++ b/api_query_helper.go @@ -15,7 +15,7 @@ import ( "fmt" "reflect" - "github.com/hypermodeinc/modusdb/api/dql_query" + "github.com/hypermodeinc/modusdb/api/query_gen" "github.com/hypermodeinc/modusdb/api/utils" ) @@ -57,7 +57,8 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac readFromQuery := "" if withReverse { for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(dql_query.ReverseEdgeQuery, utils.GetPredicateName(t.Name(), jsonTag), reverseEdgeTag) + readFromQuery += fmt.Sprintf(query_gen.ReverseEdgeQuery, + utils.GetPredicateName(t.Name(), jsonTag), reverseEdgeTag) } } @@ -65,9 +66,9 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac var query string gid, ok := any(args[0]).(uint64) if ok { - query = dql_query.FormatObjQuery(dql_query.BuildUidQuery(gid), readFromQuery) + query = query_gen.FormatObjQuery(query_gen.BuildUidQuery(gid), readFromQuery) } else if cf, ok = any(args[0]).(ConstrainedField); ok { - query = dql_query.FormatObjQuery(dql_query.BuildEqQuery(utils.GetPredicateName(t.Name(), + query = query_gen.FormatObjQuery(query_gen.BuildEqQuery(utils.GetPredicateName(t.Name(), cf.Key), cf.Value), readFromQuery) } else { return 0, obj, fmt.Errorf("invalid unique field type") @@ -102,22 +103,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac return 0, obj, utils.ErrNoObjFound } - // Map the dynamic struct to the final type T - finalObject := reflect.New(t).Interface() - gid, err = utils.MapDynamicToFinal(result.Obj[0], finalObject, false) - if err != nil { - return 0, obj, err - } - - if typedPtr, ok := finalObject.(*T); ok { - return gid, *typedPtr, nil - } - - if dirType, ok := finalObject.(T); ok { - return gid, dirType, nil - } - - return 0, obj, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) + return utils.ConvertDynamicToTyped[T](result.Obj[0], t) } func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryParams, @@ -129,7 +115,7 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar return nil, nil, err } - var filterQueryFunc dql_query.QueryFunc = func() string { + var filterQueryFunc query_gen.QueryFunc = func() string { return "" } var paginationAndSorting string @@ -150,11 +136,11 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar readFromQuery := "" if withReverse { for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(dql_query.ReverseEdgeQuery, utils.GetPredicateName(t.Name(), jsonTag), reverseEdgeTag) + readFromQuery += fmt.Sprintf(query_gen.ReverseEdgeQuery, utils.GetPredicateName(t.Name(), jsonTag), reverseEdgeTag) } } - query := dql_query.FormatObjsQuery(t.Name(), filterQueryFunc, paginationAndSorting, readFromQuery) + query := query_gen.FormatObjsQuery(t.Name(), filterQueryFunc, paginationAndSorting, readFromQuery) resp, err := n.queryWithLock(ctx, query) if err != nil { @@ -188,21 +174,12 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar var gids []uint64 var objs []T for _, obj := range result.Objs { - finalObject := reflect.New(t).Interface() - gid, err := utils.MapDynamicToFinal(obj, finalObject, false) + gid, typedObj, err := utils.ConvertDynamicToTyped[T](obj, t) if err != nil { return nil, nil, err } - - if typedPtr, ok := finalObject.(*T); ok { - gids = append(gids, gid) - objs = append(objs, *typedPtr) - } else if dirType, ok := finalObject.(T); ok { - gids = append(gids, gid) - objs = append(objs, dirType) - } else { - return nil, nil, fmt.Errorf("failed to convert type %T to %T", finalObject, obj) - } + gids = append(gids, gid) + objs = append(objs, typedObj) } return gids, objs, nil diff --git a/api_types.go b/api_types.go index ebcb4b2..90d0af1 100644 --- a/api_types.go +++ b/api_types.go @@ -11,19 +11,12 @@ package modusdb import ( "context" - "encoding/binary" "fmt" "strings" - "time" - "github.com/dgraph-io/dgo/v240/protos/api" - "github.com/dgraph-io/dgraph/v24/protos/pb" - "github.com/dgraph-io/dgraph/v24/types" "github.com/dgraph-io/dgraph/v24/x" - "github.com/hypermodeinc/modusdb/api/dql_query" + "github.com/hypermodeinc/modusdb/api/query_gen" "github.com/hypermodeinc/modusdb/api/utils" - "github.com/twpayne/go-geom" - "github.com/twpayne/go-geom/encoding/wkb" ) type UniqueField interface { @@ -110,144 +103,60 @@ func getDefaultNamespace(db *DB, ns ...uint64) (context.Context, *Namespace, err return ctx, n, nil } -func valueToPosting_ValType(v any) (pb.Posting_ValType, error) { - switch v.(type) { - case string: - return pb.Posting_STRING, nil - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32: - return pb.Posting_INT, nil - case uint64: - return pb.Posting_UID, nil - case bool: - return pb.Posting_BOOL, nil - case float32, float64: - return pb.Posting_FLOAT, nil - case []byte: - return pb.Posting_BINARY, nil - case time.Time: - return pb.Posting_DATETIME, nil - case geom.Point: - return pb.Posting_GEO, nil - case []float32, []float64: - return pb.Posting_VFLOAT, nil - default: - return pb.Posting_DEFAULT, fmt.Errorf("unsupported type %T", v) - } -} - -func valueToApiVal(v any) (*api.Value, error) { - switch val := v.(type) { - case string: - return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil - case int: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int8: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int16: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int32: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case int64: - return &api.Value{Val: &api.Value_IntVal{IntVal: val}}, nil - case uint8: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case uint16: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case uint32: - return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}}, nil - case uint64: - return &api.Value{Val: &api.Value_UidVal{UidVal: val}}, nil - case bool: - return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}}, nil - case float32: - return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil - case float64: - return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil - case []float32: - return &api.Value{Val: &api.Value_Vfloat32Val{ - Vfloat32Val: types.FloatArrayAsBytes(val)}}, nil - case []float64: - float32Slice := make([]float32, len(val)) - for i, v := range val { - float32Slice[i] = float32(v) - } - return &api.Value{Val: &api.Value_Vfloat32Val{ - Vfloat32Val: types.FloatArrayAsBytes(float32Slice)}}, nil - case []byte: - return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil - case time.Time: - bytes, err := val.MarshalBinary() - if err != nil { - return nil, err - } - return &api.Value{Val: &api.Value_DateVal{DateVal: bytes}}, nil - case geom.Point: - bytes, err := wkb.Marshal(&val, binary.LittleEndian) - if err != nil { - return nil, err - } - return &api.Value{Val: &api.Value_GeoVal{GeoVal: bytes}}, nil - case uint: - return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}}, nil - default: - return nil, fmt.Errorf("unsupported type %T", v) - } -} - -func filterToQueryFunc(typeName string, f Filter) dql_query.QueryFunc { +func filterToQueryFunc(typeName string, f Filter) query_gen.QueryFunc { // Handle logical operators first if f.And != nil { - return dql_query.And(filterToQueryFunc(typeName, *f.And)) + return query_gen.And(filterToQueryFunc(typeName, *f.And)) } if f.Or != nil { - return dql_query.Or(filterToQueryFunc(typeName, *f.Or)) + return query_gen.Or(filterToQueryFunc(typeName, *f.Or)) } if f.Not != nil { - return dql_query.Not(filterToQueryFunc(typeName, *f.Not)) + return query_gen.Not(filterToQueryFunc(typeName, *f.Not)) } // Handle field predicates if f.String.Equals != "" { - return dql_query.BuildEqQuery(utils.GetPredicateName(typeName, f.Field), f.String.Equals) + return query_gen.BuildEqQuery(utils.GetPredicateName(typeName, f.Field), f.String.Equals) } if len(f.String.AllOfTerms) != 0 { - return dql_query.BuildAllOfTermsQuery(utils.GetPredicateName(typeName, + return query_gen.BuildAllOfTermsQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AllOfTerms, " ")) } if len(f.String.AnyOfTerms) != 0 { - return dql_query.BuildAnyOfTermsQuery(utils.GetPredicateName(typeName, + return query_gen.BuildAnyOfTermsQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfTerms, " ")) } if len(f.String.AllOfText) != 0 { - return dql_query.BuildAllOfTextQuery(utils.GetPredicateName(typeName, + return query_gen.BuildAllOfTextQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AllOfText, " ")) } if len(f.String.AnyOfText) != 0 { - return dql_query.BuildAnyOfTextQuery(utils.GetPredicateName(typeName, + return query_gen.BuildAnyOfTextQuery(utils.GetPredicateName(typeName, f.Field), strings.Join(f.String.AnyOfText, " ")) } if f.String.RegExp != "" { - return dql_query.BuildRegExpQuery(utils.GetPredicateName(typeName, + return query_gen.BuildRegExpQuery(utils.GetPredicateName(typeName, f.Field), f.String.RegExp) } if f.String.LessThan != "" { - return dql_query.BuildLtQuery(utils.GetPredicateName(typeName, + return query_gen.BuildLtQuery(utils.GetPredicateName(typeName, f.Field), f.String.LessThan) } if f.String.LessOrEqual != "" { - return dql_query.BuildLeQuery(utils.GetPredicateName(typeName, + return query_gen.BuildLeQuery(utils.GetPredicateName(typeName, f.Field), f.String.LessOrEqual) } if f.String.GreaterThan != "" { - return dql_query.BuildGtQuery(utils.GetPredicateName(typeName, + return query_gen.BuildGtQuery(utils.GetPredicateName(typeName, f.Field), f.String.GreaterThan) } if f.String.GreaterOrEqual != "" { - return dql_query.BuildGeQuery(utils.GetPredicateName(typeName, + return query_gen.BuildGeQuery(utils.GetPredicateName(typeName, f.Field), f.String.GreaterOrEqual) } if f.Vector.SimilarTo != nil { - return dql_query.BuildSimilarToQuery(utils.GetPredicateName(typeName, + return query_gen.BuildSimilarToQuery(utils.GetPredicateName(typeName, f.Field), f.Vector.TopK, f.Vector.SimilarTo) } @@ -256,7 +165,7 @@ func filterToQueryFunc(typeName string, f Filter) dql_query.QueryFunc { } // Helper function to combine multiple filters -func filtersToQueryFunc(typeName string, filter Filter) dql_query.QueryFunc { +func filtersToQueryFunc(typeName string, filter Filter) query_gen.QueryFunc { return filterToQueryFunc(typeName, filter) } diff --git a/api.go b/apis.go similarity index 95% rename from api.go rename to apis.go index 311baf0..bfa7c7e 100644 --- a/api.go +++ b/apis.go @@ -35,7 +35,7 @@ func Create[T any](db *DB, object T, ns ...uint64) (uint64, T, error) { dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) if err != nil { return 0, object, err } @@ -74,7 +74,7 @@ func Upsert[T any](db *DB, object T, ns ...uint64) (uint64, T, bool, error) { dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateCreateDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) if err != nil { return 0, object, false, err } @@ -105,7 +105,7 @@ func Upsert[T any](db *DB, object T, ns ...uint64) (uint64, T, bool, error) { } dms = make([]*dql.Mutation, 0) - err = generateCreateDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) if err != nil { return 0, object, false, err } diff --git a/api_test.go b/apis_test.go similarity index 100% rename from api_test.go rename to apis_test.go diff --git a/mutation_create.go b/mutation_create.go new file mode 100644 index 0000000..3bfdc6f --- /dev/null +++ b/mutation_create.go @@ -0,0 +1,121 @@ +/* + * Copyright 2025 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2025 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusdb + +import ( + "context" + "fmt" + "reflect" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/dql" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/schema" + "github.com/dgraph-io/dgraph/v24/x" + "github.com/hypermodeinc/modusdb/api/mutations" + "github.com/hypermodeinc/modusdb/api/utils" +) + +func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object T, + gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { + t := reflect.TypeOf(object) + if t.Kind() != reflect.Struct { + return fmt.Errorf("expected struct, got %s", t.Kind()) + } + + fieldToJsonTags, jsonToDbTags, jsonToReverseEdgeTags, err := utils.GetFieldTags(t) + if err != nil { + return err + } + jsonTagToValue := utils.GetJsonTagToValues(object, fieldToJsonTags) + + nquads := make([]*api.NQuad, 0) + uniqueConstraintFound := false + for jsonName, value := range jsonTagToValue { + + reflectValueType := reflect.TypeOf(value) + var nquad *api.NQuad + + if jsonToReverseEdgeTags[jsonName] != "" { + if err := mutations.HandleReverseEdge(jsonName, reflectValueType, n.id, sch, jsonToReverseEdgeTags); err != nil { + return err + } + continue + } + if jsonName == "gid" { + uniqueConstraintFound = true + continue + } + + value, err = processStructValue(ctx, value, n) + if err != nil { + return err + } + + value, err = processPointerValue(ctx, value, n) + if err != nil { + return err + } + + nquad, u, err := mutations.CreateNQuadAndSchema(value, gid, jsonName, t, n.ID()) + if err != nil { + return err + } + + uniqueConstraintFound, err = utils.HandleConstraints(u, jsonToDbTags, jsonName, u.ValueType, uniqueConstraintFound) + if err != nil { + return err + } + + sch.Preds = append(sch.Preds, u) + nquads = append(nquads, nquad) + } + if !uniqueConstraintFound { + return fmt.Errorf(utils.NoUniqueConstr, t.Name()) + } + + sch.Types = append(sch.Types, &pb.TypeUpdate{ + TypeName: utils.AddNamespace(n.id, t.Name()), + Fields: sch.Preds, + }) + + val, err := utils.ValueToApiVal(t.Name()) + if err != nil { + return err + } + typeNquad := &api.NQuad{ + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: "dgraph.type", + ObjectValue: val, + } + nquads = append(nquads, typeNquad) + + *dms = append(*dms, &dql.Mutation{ + Set: nquads, + }) + + return nil +} + +func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { + return []*dql.Mutation{{ + Del: []*api.NQuad{ + { + Namespace: n.ID(), + Subject: fmt.Sprint(gid), + Predicate: x.Star, + ObjectValue: &api.Value{ + Val: &api.Value_DefaultVal{DefaultVal: x.Star}, + }, + }, + }, + }} +} diff --git a/mutation_helpers.go b/mutation_helpers.go new file mode 100644 index 0000000..899e61b --- /dev/null +++ b/mutation_helpers.go @@ -0,0 +1,131 @@ +package modusdb + +import ( + "context" + "fmt" + "reflect" + + "github.com/dgraph-io/dgraph/v24/dql" + "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/query" + "github.com/dgraph-io/dgraph/v24/schema" + "github.com/dgraph-io/dgraph/v24/worker" + "github.com/hypermodeinc/modusdb/api/utils" +) + +func processStructValue(ctx context.Context, value any, n *Namespace) (any, error) { + if reflect.TypeOf(value).Kind() == reflect.Struct { + value = reflect.ValueOf(value).Interface() + newGid, err := getUidOrMutate(ctx, n.db, n, value) + if err != nil { + return nil, err + } + return newGid, nil + } + return value, nil +} + +func processPointerValue(ctx context.Context, value any, n *Namespace) (any, error) { + reflectValueType := reflect.TypeOf(value) + if reflectValueType.Kind() == reflect.Pointer { + reflectValueType = reflectValueType.Elem() + if reflectValueType.Kind() == reflect.Struct { + value = reflect.ValueOf(value).Elem().Interface() + return processStructValue(ctx, value, n) + } + } + return value, nil +} + +func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) (uint64, error) { + gid, cf, err := GetUniqueConstraint[T](object) + if err != nil { + return 0, err + } + + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateSetDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, err + } + + err = n.alterSchemaWithParsed(ctx, sch) + if err != nil { + return 0, err + } + if gid != 0 { + gid, _, err = getByGidWithObject[T](ctx, n, gid, object) + if err != nil && err != utils.ErrNoObjFound { + return 0, err + } + if err == nil { + return gid, nil + } + } else if cf != nil { + gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) + if err != nil && err != utils.ErrNoObjFound { + return 0, err + } + if err == nil { + return gid, nil + } + } + + gid, err = db.z.nextUID() + if err != nil { + return 0, err + } + + dms = make([]*dql.Mutation, 0) + err = generateSetDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, err + } + + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, err + } + + return gid, nil +} + +func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { + edges, err := query.ToDirectedEdges(dms, nil) + if err != nil { + return err + } + + if !db.isOpen { + return ErrClosedDB + } + + startTs, err := db.z.nextTs() + if err != nil { + return err + } + commitTs, err := db.z.nextTs() + if err != nil { + return err + } + + m := &pb.Mutations{ + GroupId: 1, + StartTs: startTs, + Edges: edges, + } + m.Edges, err = query.ExpandEdges(ctx, m) + if err != nil { + return fmt.Errorf("error expanding edges: %w", err) + } + + p := &pb.Proposal{Mutations: m, StartTs: startTs} + if err := worker.ApplyMutations(ctx, p); err != nil { + return err + } + + return worker.ApplyCommited(ctx, &pb.OracleDelta{ + Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, + }) +} From e1d42348e1c755eac34f4bf22ffd0ca103443903 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 09:59:26 -0800 Subject: [PATCH 12/17] . --- api/mutations/mutations.go | 6 ++++-- api/utils/dgraph.go | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/api/mutations/mutations.go b/api/mutations/mutations.go index 1217499..4e4fedf 100644 --- a/api/mutations/mutations.go +++ b/api/mutations/mutations.go @@ -11,7 +11,8 @@ import ( "github.com/hypermodeinc/modusdb/api/utils" ) -func HandleReverseEdge(jsonName string, value reflect.Type, nsId uint64, sch *schema.ParsedSchema, jsonToReverseEdgeTags map[string]string) error { +func HandleReverseEdge(jsonName string, value reflect.Type, nsId uint64, sch *schema.ParsedSchema, + jsonToReverseEdgeTags map[string]string) error { if jsonToReverseEdgeTags[jsonName] == "" { return nil } @@ -36,7 +37,8 @@ func HandleReverseEdge(jsonName string, value reflect.Type, nsId uint64, sch *sc return nil } -func CreateNQuadAndSchema(value any, gid uint64, jsonName string, t reflect.Type, nsId uint64) (*api.NQuad, *pb.SchemaUpdate, error) { +func CreateNQuadAndSchema(value any, gid uint64, jsonName string, t reflect.Type, + nsId uint64) (*api.NQuad, *pb.SchemaUpdate, error) { valType, err := utils.ValueToPosting_ValType(value) if err != nil { return nil, nil, err diff --git a/api/utils/dgraph.go b/api/utils/dgraph.go index c89cf0c..fc49050 100644 --- a/api/utils/dgraph.go +++ b/api/utils/dgraph.go @@ -132,7 +132,8 @@ func ValueToApiVal(v any) (*api.Value, error) { } } -func HandleConstraints(u *pb.SchemaUpdate, jsonToDbTags map[string]*DbTag, jsonName string, valType pb.Posting_ValType, uniqueConstraintFound bool) (bool, error) { +func HandleConstraints(u *pb.SchemaUpdate, jsonToDbTags map[string]*DbTag, jsonName string, + valType pb.Posting_ValType, uniqueConstraintFound bool) (bool, error) { if jsonToDbTags[jsonName] == nil { return uniqueConstraintFound, nil } From b4942d2574bf64a62cc9f18741214c8a12619e3f Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:04:28 -0800 Subject: [PATCH 13/17] . --- apis.go => api.go | 0 apis_test.go => api_test.go | 0 api_query_helper.go => query_execution.go | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename apis.go => api.go (100%) rename apis_test.go => api_test.go (100%) rename api_query_helper.go => query_execution.go (100%) diff --git a/apis.go b/api.go similarity index 100% rename from apis.go rename to api.go diff --git a/apis_test.go b/api_test.go similarity index 100% rename from apis_test.go rename to api_test.go diff --git a/api_query_helper.go b/query_execution.go similarity index 100% rename from api_query_helper.go rename to query_execution.go From 5a1ef51570c586a8030015d81f9f06bfe5ad8711 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:06:47 -0800 Subject: [PATCH 14/17] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d92b85..e92097c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ - feat: add readfrom json tag to support reverse edges [#49](https://github.com/hypermodeinc/modusDB/pull/49) +- chore: Refactoring package management #51 [#51](https://github.com/hypermodeinc/modusDB/pull/51) + ## 2025-01-02 - Version 0.1.0 Baseline for the changelog. From b483510d812d42c46de8e64b26dd5bf40386f482 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:07:55 -0800 Subject: [PATCH 15/17] . --- query_execution.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/query_execution.go b/query_execution.go index 1b7b726..3a0133b 100644 --- a/query_execution.go +++ b/query_execution.go @@ -171,15 +171,15 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar return nil, nil, err } - var gids []uint64 - var objs []T - for _, obj := range result.Objs { + gids := make([]uint64, len(result.Objs)) + objs := make([]T, len(result.Objs)) + for i, obj := range result.Objs { gid, typedObj, err := utils.ConvertDynamicToTyped[T](obj, t) if err != nil { return nil, nil, err } - gids = append(gids, gid) - objs = append(objs, typedObj) + gids[i] = gid + objs[i] = typedObj } return gids, objs, nil From 0cbe92488c758f942f088b4e75112a99ed144225 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:19:10 -0800 Subject: [PATCH 16/17] . --- api.go | 11 +++-------- mutation_create.go => api_mutation_gen.go | 0 mutation_helpers.go | 12 ++---------- query_execution.go | 13 +++++++++++++ 4 files changed, 18 insertions(+), 18 deletions(-) rename mutation_create.go => api_mutation_gen.go (100%) diff --git a/api.go b/api.go index bfa7c7e..beadf9f 100644 --- a/api.go +++ b/api.go @@ -84,19 +84,14 @@ func Upsert[T any](db *DB, object T, ns ...uint64) (uint64, T, bool, error) { return 0, object, false, err } - if gid != 0 { - gid, _, err = getByGidWithObject[T](ctx, n, gid, object) - if err != nil && err != utils.ErrNoObjFound { - return 0, object, false, err - } - wasFound = err == nil - } else if cf != nil { - gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) + if gid != 0 || cf != nil { + gid, err = getExistingObject[T](ctx, n, gid, cf, object) if err != nil && err != utils.ErrNoObjFound { return 0, object, false, err } wasFound = err == nil } + if gid == 0 { gid, err = db.z.nextUID() if err != nil { diff --git a/mutation_create.go b/api_mutation_gen.go similarity index 100% rename from mutation_create.go rename to api_mutation_gen.go diff --git a/mutation_helpers.go b/mutation_helpers.go index 899e61b..d4b0cf4 100644 --- a/mutation_helpers.go +++ b/mutation_helpers.go @@ -54,16 +54,8 @@ func getUidOrMutate[T any](ctx context.Context, db *DB, n *Namespace, object T) if err != nil { return 0, err } - if gid != 0 { - gid, _, err = getByGidWithObject[T](ctx, n, gid, object) - if err != nil && err != utils.ErrNoObjFound { - return 0, err - } - if err == nil { - return gid, nil - } - } else if cf != nil { - gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) + if gid != 0 || cf != nil { + gid, err = getExistingObject(ctx, n, gid, cf, object) if err != nil && err != utils.ErrNoObjFound { return 0, err } diff --git a/query_execution.go b/query_execution.go index 3a0133b..dc77b66 100644 --- a/query_execution.go +++ b/query_execution.go @@ -184,3 +184,16 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar return gids, objs, nil } + +func getExistingObject[T any](ctx context.Context, n *Namespace, gid uint64, cf *ConstrainedField, object T) (uint64, error) { + var err error + if gid != 0 { + gid, _, err = getByGidWithObject[T](ctx, n, gid, object) + } else if cf != nil { + gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) + } + if err != nil { + return 0, err + } + return gid, nil +} From ea882aaa6a4c5abb6c1e3be7b5ce5e44c348e529 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:20:51 -0800 Subject: [PATCH 17/17] . --- mutation_helpers.go => api_mutation_helpers.go | 0 query_execution.go => api_query_execution.go | 3 ++- 2 files changed, 2 insertions(+), 1 deletion(-) rename mutation_helpers.go => api_mutation_helpers.go (100%) rename query_execution.go => api_query_execution.go (98%) diff --git a/mutation_helpers.go b/api_mutation_helpers.go similarity index 100% rename from mutation_helpers.go rename to api_mutation_helpers.go diff --git a/query_execution.go b/api_query_execution.go similarity index 98% rename from query_execution.go rename to api_query_execution.go index dc77b66..bf9cb23 100644 --- a/query_execution.go +++ b/api_query_execution.go @@ -185,7 +185,8 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar return gids, objs, nil } -func getExistingObject[T any](ctx context.Context, n *Namespace, gid uint64, cf *ConstrainedField, object T) (uint64, error) { +func getExistingObject[T any](ctx context.Context, n *Namespace, gid uint64, cf *ConstrainedField, + object T) (uint64, error) { var err error if gid != 0 { gid, _, err = getByGidWithObject[T](ctx, n, gid, object)