From 1cdd289564c6430fd4d5d5e976c6e7bdad829793 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 7 Feb 2025 09:51:11 -0800 Subject: [PATCH 1/2] add context to apis --- api.go | 26 +++++++----- api_types.go | 3 +- unit_test/api_test.go | 97 ++++++++++++++++++++++--------------------- 3 files changed, 67 insertions(+), 59 deletions(-) diff --git a/api.go b/api.go index 98eda6d..4bbd0d7 100644 --- a/api.go +++ b/api.go @@ -6,6 +6,7 @@ package modusdb import ( + "context" "fmt" "github.com/hypermodeinc/dgraph/v24/dql" @@ -14,13 +15,14 @@ import ( "github.com/hypermodeinc/modusdb/api/structreflect" ) -func Create[T any](engine *Engine, object T, nsId ...uint64) (uint64, T, error) { +func Create[T any](ctx context.Context, engine *Engine, object T, + nsId ...uint64) (uint64, T, error) { engine.mutex.Lock() defer engine.mutex.Unlock() if len(nsId) > 1 { return 0, object, fmt.Errorf("only one namespace is allowed") } - ctx, ns, err := getDefaultNamespace(engine, nsId...) + ctx, ns, err := getDefaultNamespace(ctx, engine, nsId...) if err != nil { return 0, object, err } @@ -50,7 +52,8 @@ func Create[T any](engine *Engine, object T, nsId ...uint64) (uint64, T, error) return getByGid[T](ctx, ns, gid) } -func Upsert[T any](engine *Engine, object T, nsId ...uint64) (uint64, T, bool, error) { +func Upsert[T any](ctx context.Context, engine *Engine, object T, + nsId ...uint64) (uint64, T, bool, error) { var wasFound bool engine.mutex.Lock() @@ -59,7 +62,7 @@ func Upsert[T any](engine *Engine, object T, nsId ...uint64) (uint64, T, bool, e return 0, object, false, fmt.Errorf("only one namespace is allowed") } - ctx, ns, err := getDefaultNamespace(engine, nsId...) + ctx, ns, err := getDefaultNamespace(ctx, engine, nsId...) if err != nil { return 0, object, false, err } @@ -122,14 +125,15 @@ func Upsert[T any](engine *Engine, object T, nsId ...uint64) (uint64, T, bool, e return gid, object, wasFound, nil } -func Get[T any, R UniqueField](engine *Engine, uniqueField R, nsId ...uint64) (uint64, T, error) { +func Get[T any, R UniqueField](ctx context.Context, engine *Engine, uniqueField R, + nsId ...uint64) (uint64, T, error) { engine.mutex.Lock() defer engine.mutex.Unlock() var obj T if len(nsId) > 1 { return 0, obj, fmt.Errorf("only one namespace is allowed") } - ctx, ns, err := getDefaultNamespace(engine, nsId...) + ctx, ns, err := getDefaultNamespace(ctx, engine, nsId...) if err != nil { return 0, obj, err } @@ -144,13 +148,14 @@ func Get[T any, R UniqueField](engine *Engine, uniqueField R, nsId ...uint64) (u return 0, obj, fmt.Errorf("invalid unique field type") } -func Query[T any](engine *Engine, queryParams QueryParams, nsId ...uint64) ([]uint64, []T, error) { +func Query[T any](ctx context.Context, engine *Engine, queryParams QueryParams, + nsId ...uint64) ([]uint64, []T, error) { engine.mutex.Lock() defer engine.mutex.Unlock() if len(nsId) > 1 { return nil, nil, fmt.Errorf("only one namespace is allowed") } - ctx, ns, err := getDefaultNamespace(engine, nsId...) + ctx, ns, err := getDefaultNamespace(ctx, engine, nsId...) if err != nil { return nil, nil, err } @@ -158,14 +163,15 @@ func Query[T any](engine *Engine, queryParams QueryParams, nsId ...uint64) ([]ui return executeQuery[T](ctx, ns, queryParams, true) } -func Delete[T any, R UniqueField](engine *Engine, uniqueField R, nsId ...uint64) (uint64, T, error) { +func Delete[T any, R UniqueField](ctx context.Context, engine *Engine, uniqueField R, + nsId ...uint64) (uint64, T, error) { engine.mutex.Lock() defer engine.mutex.Unlock() var zeroObj T if len(nsId) > 1 { return 0, zeroObj, fmt.Errorf("only one namespace is allowed") } - ctx, ns, err := getDefaultNamespace(engine, nsId...) + ctx, ns, err := getDefaultNamespace(ctx, engine, nsId...) if err != nil { return 0, zeroObj, err } diff --git a/api_types.go b/api_types.go index 728da61..b8acdac 100644 --- a/api_types.go +++ b/api_types.go @@ -80,7 +80,7 @@ func WithNamespace(ns uint64) ModusDbOption { } } -func getDefaultNamespace(engine *Engine, nsId ...uint64) (context.Context, *Namespace, error) { +func getDefaultNamespace(ctx context.Context, engine *Engine, nsId ...uint64) (context.Context, *Namespace, error) { dbOpts := &modusDbOptions{ ns: engine.db0.ID(), } @@ -93,7 +93,6 @@ func getDefaultNamespace(engine *Engine, nsId ...uint64) (context.Context, *Name return nil, nil, err } - ctx := context.Background() ctx = x.AttachNamespace(ctx, d.ID()) return ctx, d, nil diff --git a/unit_test/api_test.go b/unit_test/api_test.go index 9ec75d3..a404547 100644 --- a/unit_test/api_test.go +++ b/unit_test/api_test.go @@ -28,7 +28,7 @@ func TestFirstTimeUser(t *testing.T) { require.NoError(t, err) defer engine.Close() - gid, user, err := modusdb.Create(engine, User{ + gid, user, err := modusdb.Create(context.Background(), engine, User{ Name: "A", Age: 10, ClerkId: "123", @@ -40,7 +40,7 @@ func TestFirstTimeUser(t *testing.T) { require.Equal(t, 10, user.Age) require.Equal(t, "123", user.ClerkId) - gid, queriedUser, err := modusdb.Get[User](engine, gid) + gid, queriedUser, err := modusdb.Get[User](context.Background(), engine, gid) require.NoError(t, err) require.Equal(t, queriedUser.Gid, gid) @@ -48,7 +48,7 @@ func TestFirstTimeUser(t *testing.T) { require.Equal(t, "A", queriedUser.Name) require.Equal(t, "123", queriedUser.ClerkId) - gid, queriedUser2, err := modusdb.Get[User](engine, modusdb.ConstrainedField{ + gid, queriedUser2, err := modusdb.Get[User](context.Background(), engine, modusdb.ConstrainedField{ Key: "clerk_id", Value: "123", }) @@ -59,10 +59,10 @@ func TestFirstTimeUser(t *testing.T) { require.Equal(t, "A", queriedUser2.Name) require.Equal(t, "123", queriedUser2.ClerkId) - _, _, err = modusdb.Delete[User](engine, gid) + _, _, err = modusdb.Delete[User](context.Background(), engine, gid) require.NoError(t, err) - _, queriedUser3, err := modusdb.Get[User](engine, gid) + _, queriedUser3, err := modusdb.Get[User](context.Background(), engine, gid) require.Error(t, err) require.Equal(t, "no object found", err.Error()) require.Equal(t, queriedUser3, User{}) @@ -86,7 +86,7 @@ func TestCreateApi(t *testing.T) { ClerkId: "123", } - gid, user, err := modusdb.Create(engine, user, ns1.ID()) + gid, user, err := modusdb.Create(context.Background(), engine, user, ns1.ID()) require.NoError(t, err) require.Equal(t, "B", user.Name) @@ -122,7 +122,7 @@ func TestCreateApiWithNonStruct(t *testing.T) { Age: 20, } - _, _, err = modusdb.Create[*User](engine, &user, ns1.ID()) + _, _, err = modusdb.Create[*User](context.Background(), engine, &user, ns1.ID()) require.Error(t, err) require.Equal(t, "expected struct, got ptr", err.Error()) } @@ -144,10 +144,10 @@ func TestGetApi(t *testing.T) { ClerkId: "123", } - gid, _, err := modusdb.Create(engine, user, ns1.ID()) + gid, _, err := modusdb.Create(context.Background(), engine, user, ns1.ID()) require.NoError(t, err) - gid, queriedUser, err := modusdb.Get[User](engine, gid, ns1.ID()) + gid, queriedUser, err := modusdb.Get[User](context.Background(), engine, gid, ns1.ID()) require.NoError(t, err) require.Equal(t, queriedUser.Gid, gid) @@ -173,10 +173,10 @@ func TestGetApiWithConstrainedField(t *testing.T) { ClerkId: "123", } - _, _, err = modusdb.Create(engine, user, ns1.ID()) + _, _, err = modusdb.Create(context.Background(), engine, user, ns1.ID()) require.NoError(t, err) - gid, queriedUser, err := modusdb.Get[User](engine, modusdb.ConstrainedField{ + gid, queriedUser, err := modusdb.Get[User](context.Background(), engine, modusdb.ConstrainedField{ Key: "clerk_id", Value: "123", }, ns1.ID()) @@ -205,18 +205,18 @@ func TestDeleteApi(t *testing.T) { ClerkId: "123", } - gid, _, err := modusdb.Create(engine, user, ns1.ID()) + gid, _, err := modusdb.Create(context.Background(), engine, user, ns1.ID()) require.NoError(t, err) - _, _, err = modusdb.Delete[User](engine, gid, ns1.ID()) + _, _, err = modusdb.Delete[User](context.Background(), engine, gid, ns1.ID()) require.NoError(t, err) - _, queriedUser, err := modusdb.Get[User](engine, gid, ns1.ID()) + _, queriedUser, err := modusdb.Get[User](context.Background(), engine, gid, ns1.ID()) require.Error(t, err) require.Equal(t, "no object found", err.Error()) require.Equal(t, queriedUser, User{}) - _, queriedUser, err = modusdb.Get[User](engine, modusdb.ConstrainedField{ + _, queriedUser, err = modusdb.Get[User](context.Background(), engine, modusdb.ConstrainedField{ Key: "clerk_id", Value: "123", }, ns1.ID()) @@ -242,16 +242,16 @@ func TestUpsertApi(t *testing.T) { ClerkId: "123", } - gid, user, _, err := modusdb.Upsert(engine, user, ns1.ID()) + gid, user, _, err := modusdb.Upsert(context.Background(), engine, user, ns1.ID()) require.NoError(t, err) require.Equal(t, user.Gid, gid) user.Age = 21 - gid, _, _, err = modusdb.Upsert(engine, user, ns1.ID()) + gid, _, _, err = modusdb.Upsert(context.Background(), engine, user, ns1.ID()) require.NoError(t, err) require.Equal(t, user.Gid, gid) - _, queriedUser, err := modusdb.Get[User](engine, gid, ns1.ID()) + _, queriedUser, err := modusdb.Get[User](context.Background(), engine, gid, ns1.ID()) require.NoError(t, err) require.Equal(t, user.Gid, queriedUser.Gid) require.Equal(t, 21, queriedUser.Age) @@ -279,11 +279,11 @@ func TestQueryApi(t *testing.T) { } for _, user := range users { - _, _, err = modusdb.Create(engine, user, ns1.ID()) + _, _, err = modusdb.Create(context.Background(), engine, user, ns1.ID()) require.NoError(t, err) } - gids, queriedUsers, err := modusdb.Query[User](engine, modusdb.QueryParams{}, ns1.ID()) + gids, queriedUsers, err := modusdb.Query[User](context.Background(), engine, modusdb.QueryParams{}, ns1.ID()) require.NoError(t, err) require.Len(t, queriedUsers, 5) require.Len(t, gids, 5) @@ -293,7 +293,7 @@ func TestQueryApi(t *testing.T) { require.Equal(t, "D", queriedUsers[3].Name) require.Equal(t, "E", queriedUsers[4].Name) - gids, queriedUsers, err = modusdb.Query[User](engine, modusdb.QueryParams{ + gids, queriedUsers, err = modusdb.Query[User](context.Background(), engine, modusdb.QueryParams{ Filter: &modusdb.Filter{ Field: "age", String: modusdb.StringPredicate{ @@ -334,11 +334,11 @@ func TestQueryApiWithPaginiationAndSorting(t *testing.T) { } for _, user := range users { - _, _, err = modusdb.Create(engine, user, ns1.ID()) + _, _, err = modusdb.Create(context.Background(), engine, user, ns1.ID()) require.NoError(t, err) } - gids, queriedUsers, err := modusdb.Query[User](engine, modusdb.QueryParams{ + gids, queriedUsers, err := modusdb.Query[User](context.Background(), engine, modusdb.QueryParams{ Filter: &modusdb.Filter{ Field: "age", String: modusdb.StringPredicate{ @@ -358,7 +358,7 @@ func TestQueryApiWithPaginiationAndSorting(t *testing.T) { require.Equal(t, "D", queriedUsers[1].Name) require.Equal(t, "E", queriedUsers[2].Name) - gids, queriedUsers, err = modusdb.Query[User](engine, modusdb.QueryParams{ + gids, queriedUsers, err = modusdb.Query[User](context.Background(), engine, modusdb.QueryParams{ Pagination: &modusdb.Pagination{ Limit: 3, Offset: 1, @@ -401,7 +401,7 @@ func TestReverseEdgeGet(t *testing.T) { require.NoError(t, ns1.DropData(ctx)) - projGid, project, err := modusdb.Create(engine, Project{ + projGid, project, err := modusdb.Create(context.Background(), engine, Project{ Name: "P", ClerkId: "456", Branches: []Branch{ @@ -425,7 +425,7 @@ func TestReverseEdgeGet(t *testing.T) { }, } - branch1Gid, branch1, err := modusdb.Create(engine, branch1, ns1.ID()) + branch1Gid, branch1, err := modusdb.Create(context.Background(), engine, branch1, ns1.ID()) require.NoError(t, err) require.Equal(t, "B", branch1.Name) @@ -441,13 +441,13 @@ func TestReverseEdgeGet(t *testing.T) { }, } - branch2Gid, branch2, err := modusdb.Create(engine, branch2, ns1.ID()) + branch2Gid, branch2, err := modusdb.Create(context.Background(), engine, branch2, ns1.ID()) require.NoError(t, err) require.Equal(t, "B2", branch2.Name) require.Equal(t, branch2.Gid, branch2Gid) require.Equal(t, projGid, branch2.Proj.Gid) - getProjGid, queriedProject, err := modusdb.Get[Project](engine, projGid, ns1.ID()) + getProjGid, queriedProject, err := modusdb.Get[Project](context.Background(), engine, projGid, ns1.ID()) require.NoError(t, err) require.Equal(t, projGid, getProjGid) require.Equal(t, "P", queriedProject.Name) @@ -455,7 +455,8 @@ func TestReverseEdgeGet(t *testing.T) { require.Equal(t, "B", queriedProject.Branches[0].Name) require.Equal(t, "B2", queriedProject.Branches[1].Name) - queryBranchesGids, queriedBranches, err := modusdb.Query[Branch](engine, modusdb.QueryParams{}, ns1.ID()) + queryBranchesGids, queriedBranches, err := modusdb.Query[Branch](context.Background(), engine, + modusdb.QueryParams{}, ns1.ID()) require.NoError(t, err) require.Len(t, queriedBranches, 2) require.Len(t, queryBranchesGids, 2) @@ -465,10 +466,11 @@ func TestReverseEdgeGet(t *testing.T) { // max depth is 2, so we should not see the branches within project require.Len(t, queriedBranches[0].Proj.Branches, 0) - _, _, err = modusdb.Delete[Project](engine, projGid, ns1.ID()) + _, _, err = modusdb.Delete[Project](context.Background(), engine, projGid, ns1.ID()) require.NoError(t, err) - queryBranchesGids, queriedBranches, err = modusdb.Query[Branch](engine, modusdb.QueryParams{}, ns1.ID()) + queryBranchesGids, queriedBranches, err = modusdb.Query[Branch](context.Background(), engine, + modusdb.QueryParams{}, ns1.ID()) require.NoError(t, err) require.Len(t, queriedBranches, 2) require.Len(t, queryBranchesGids, 2) @@ -496,7 +498,7 @@ func TestReverseEdgeQuery(t *testing.T) { clerkCounter := 100 for _, project := range projects { - projGid, project, err := modusdb.Create(engine, project, ns1.ID()) + projGid, project, err := modusdb.Create(context.Background(), engine, project, ns1.ID()) require.NoError(t, err) require.Equal(t, project.Name, project.Name) require.Equal(t, project.Gid, projGid) @@ -509,7 +511,7 @@ func TestReverseEdgeQuery(t *testing.T) { clerkCounter += 2 for _, branch := range branches { - branchGid, branch, err := modusdb.Create(engine, branch, ns1.ID()) + branchGid, branch, err := modusdb.Create(context.Background(), engine, branch, ns1.ID()) require.NoError(t, err) require.Equal(t, branch.Name, branch.Name) require.Equal(t, branch.Gid, branchGid) @@ -517,7 +519,8 @@ func TestReverseEdgeQuery(t *testing.T) { } } - queriedProjectsGids, queriedProjects, err := modusdb.Query[Project](engine, modusdb.QueryParams{}, ns1.ID()) + queriedProjectsGids, queriedProjects, err := modusdb.Query[Project](context.Background(), + engine, modusdb.QueryParams{}, ns1.ID()) require.NoError(t, err) require.Len(t, queriedProjects, 2) require.Len(t, queriedProjectsGids, 2) @@ -551,7 +554,7 @@ func TestNestedObjectMutation(t *testing.T) { }, } - gid, branch, err := modusdb.Create(engine, branch, ns1.ID()) + gid, branch, err := modusdb.Create(context.Background(), engine, branch, ns1.ID()) require.NoError(t, err) require.Equal(t, "B", branch.Name) @@ -578,7 +581,7 @@ func TestNestedObjectMutation(t *testing.T) { {"uid":"0x3","Project.name":"P","Project.clerk_id":"456"}}]}`, string(resp.GetJson())) - gid, queriedBranch, err := modusdb.Get[Branch](engine, gid, ns1.ID()) + gid, queriedBranch, err := modusdb.Get[Branch](context.Background(), engine, gid, ns1.ID()) require.NoError(t, err) require.Equal(t, queriedBranch.Gid, gid) require.Equal(t, "B", queriedBranch.Name) @@ -596,7 +599,7 @@ func TestLinkingObjectsByConstrainedFields(t *testing.T) { require.NoError(t, ns1.DropData(ctx)) - projGid, project, err := modusdb.Create(engine, Project{ + projGid, project, err := modusdb.Create(context.Background(), engine, Project{ Name: "P", ClerkId: "456", }, ns1.ID()) @@ -614,7 +617,7 @@ func TestLinkingObjectsByConstrainedFields(t *testing.T) { }, } - gid, branch, err := modusdb.Create(engine, branch, ns1.ID()) + gid, branch, err := modusdb.Create(context.Background(), engine, branch, ns1.ID()) require.NoError(t, err) require.Equal(t, "B", branch.Name) @@ -641,7 +644,7 @@ func TestLinkingObjectsByConstrainedFields(t *testing.T) { {"uid":"0x2","Project.name":"P","Project.clerk_id":"456"}}]}`, string(resp.GetJson())) - gid, queriedBranch, err := modusdb.Get[Branch](engine, gid, ns1.ID()) + gid, queriedBranch, err := modusdb.Get[Branch](context.Background(), engine, gid, ns1.ID()) require.NoError(t, err) require.Equal(t, queriedBranch.Gid, gid) require.Equal(t, "B", queriedBranch.Name) @@ -659,7 +662,7 @@ func TestLinkingObjectsByGid(t *testing.T) { require.NoError(t, ns1.DropData(ctx)) - projGid, project, err := modusdb.Create(engine, Project{ + projGid, project, err := modusdb.Create(context.Background(), engine, Project{ Name: "P", ClerkId: "456", }, ns1.ID()) @@ -676,7 +679,7 @@ func TestLinkingObjectsByGid(t *testing.T) { }, } - gid, branch, err := modusdb.Create(engine, branch, ns1.ID()) + gid, branch, err := modusdb.Create(context.Background(), engine, branch, ns1.ID()) require.NoError(t, err) require.Equal(t, "B", branch.Name) @@ -703,7 +706,7 @@ func TestLinkingObjectsByGid(t *testing.T) { "Branch.proj":{"uid":"0x2","Project.name":"P","Project.clerk_id":"456"}}]}`, string(resp.GetJson())) - gid, queriedBranch, err := modusdb.Get[Branch](engine, gid, ns1.ID()) + gid, queriedBranch, err := modusdb.Get[Branch](context.Background(), engine, gid, ns1.ID()) require.NoError(t, err) require.Equal(t, queriedBranch.Gid, gid) require.Equal(t, "B", queriedBranch.Name) @@ -742,7 +745,7 @@ func TestNestedObjectMutationWithBadType(t *testing.T) { }, } - _, _, err = modusdb.Create(engine, branch, ns1.ID()) + _, _, err = modusdb.Create(context.Background(), engine, branch, ns1.ID()) require.Error(t, err) require.Equal(t, fmt.Sprintf(apiutils.NoUniqueConstr, "BadProject"), err.Error()) @@ -751,7 +754,7 @@ func TestNestedObjectMutationWithBadType(t *testing.T) { ClerkId: "456", } - _, _, err = modusdb.Create(engine, proj, ns1.ID()) + _, _, err = modusdb.Create(context.Background(), engine, proj, ns1.ID()) require.Error(t, err) require.Equal(t, fmt.Sprintf(apiutils.NoUniqueConstr, "BadProject"), err.Error()) @@ -785,7 +788,7 @@ func TestVectorIndexSearchTyped(t *testing.T) { } for _, doc := range documents { - _, _, err = modusdb.Create(engine, doc, ns1.ID()) + _, _, err = modusdb.Create(context.Background(), engine, doc, ns1.ID()) require.NoError(t, err) } @@ -850,11 +853,11 @@ func TestVectorIndexSearchWithQuery(t *testing.T) { } for _, doc := range documents { - _, _, err = modusdb.Create(engine, doc, ns1.ID()) + _, _, err = modusdb.Create(context.Background(), engine, doc, ns1.ID()) require.NoError(t, err) } - gids, docs, err := modusdb.Query[Document](engine, modusdb.QueryParams{ + gids, docs, err := modusdb.Query[Document](context.Background(), engine, modusdb.QueryParams{ Filter: &modusdb.Filter{ Field: "textVec", Vector: modusdb.VectorPredicate{ From d3059afcba91cfd926905b7295d9fae9660be908 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 7 Feb 2025 09:52:52 -0800 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 993afa0..4939f1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ - chore: Update dgraph dependency [#62](https://github.com/hypermodeinc/modusDB/pull/62) +- fix: add context to api functions [#69](https://github.com/hypermodeinc/modusDB/pull/69) + ## 2025-01-02 - Version 0.1.0 Baseline for the changelog.