diff --git a/entgql/internal/todo/todo_test.go b/entgql/internal/todo/todo_test.go index 7e2136c7b..e7a9814e8 100644 --- a/entgql/internal/todo/todo_test.go +++ b/entgql/internal/todo/todo_test.go @@ -2698,6 +2698,56 @@ func TestOrderByEdgeCount(t *testing.T) { } }) + t.Run("MultiOrderWithPagination", func(t *testing.T) { + var ( + // language=GraphQL + query = `query CategoryByTodosCount($first: Int, $after: Cursor) { + categories( + first: $first, + after: $after + orderBy: [{field: TODOS_COUNT, direction: DESC}], + ) { + edges { + cursor + node { + id + text + } + } + } + }` + rsp struct { + Categories struct { + Edges []struct { + Cursor string + Node struct { + ID string + Text string + } + } + } + } + ) + gqlc.MustPost( + query, + &rsp, + client.Var("first", 2), + client.Var("after", nil), + ) + require.Len(t, rsp.Categories.Edges, 2) + + // Do another query to get the next node after the first in our original query. + expectedNode := rsp.Categories.Edges[1].Node + gqlc.MustPost( + query, + &rsp, + client.Var("first", 1), + client.Var("after", rsp.Categories.Edges[0].Cursor), + ) + require.Len(t, rsp.Categories.Edges, 1) + require.Equal(t, expectedNode.ID, rsp.Categories.Edges[0].Node.ID) + }) + t.Run("NestedEdgeCountOrdering", func(t *testing.T) { var ( // language=GraphQL diff --git a/entgql/internal/todogotype/generated.go b/entgql/internal/todogotype/generated.go index 7b7091bad..c1e5faddc 100644 --- a/entgql/internal/todogotype/generated.go +++ b/entgql/internal/todogotype/generated.go @@ -20473,7 +20473,7 @@ func (ec *executionContext) marshalOInt2áš–int(ctx context.Context, sel ast.Sele return res } -func (ec *executionContext) unmarshalOMap2map(ctx context.Context, v any) (map[string]any, error) { +func (ec *executionContext) unmarshalOMap2map(ctx context.Context, v any) (map[string]interface{}, error) { if v == nil { return nil, nil } @@ -20481,7 +20481,7 @@ func (ec *executionContext) unmarshalOMap2map(ctx context.Context, v any) (map[s return res, graphql.ErrorOnPath(ctx, err) } -func (ec *executionContext) marshalOMap2map(ctx context.Context, sel ast.SelectionSet, v map[string]any) graphql.Marshaler { +func (ec *executionContext) marshalOMap2map(ctx context.Context, sel ast.SelectionSet, v map[string]interface{}) graphql.Marshaler { if v == nil { return graphql.Null } diff --git a/entgql/pagination.go b/entgql/pagination.go index 6e536f42a..c392f8458 100644 --- a/entgql/pagination.go +++ b/entgql/pagination.go @@ -146,7 +146,7 @@ func CursorsPredicate[T any](after, before *Cursor[T], idField, field string, di s.Where(sql.P(func(b *sql.Builder) { // The predicate function is executed on query generation time. column := s.C(field) - // If there is a non-ambiguis match, we use it. That is because + // If there is a non-ambiguous match, we use it. That is because // some order terms may append joined information to query selection. if matches := s.FindSelection(field); len(matches) == 1 { column = matches[0] @@ -218,16 +218,32 @@ func multiPredicate[T any](cursor *Cursor[T], opts *MultiCursorsOptions) (func(* return func(s *sql.Selector) { // Given the following terms: x DESC, y ASC, etc. The following predicate will be // generated: (x < x1 OR (x = x1 AND y > y1) OR (x = x1 AND y = y1 AND id > last)). + + // getColumnNameForField gets the name for the term and considers non-ambigous matching of + // terms that may be joined instead of a column on the table. + getColumnNameForField := func(field string) string { + // The predicate function is executed on query generation time. + column := s.C(field) + // If there is a non-ambiguous match, we use it. That is because + // some order terms may append joined information to query selection. + if matches := s.FindSelection(field); len(matches) == 1 { + column = matches[0] + } + return column + } + var or []*sql.Predicate for i := range opts.Fields { var ands []*sql.Predicate for j := 0; j < i; j++ { - ands = append(ands, sql.EQ(s.C(opts.Fields[j]), values[j])) + c := getColumnNameForField(opts.Fields[j]) + ands = append(ands, sql.EQ(c, values[j])) } + c := getColumnNameForField(opts.Fields[i]) if opts.Directions[i] == OrderDirectionAsc { - ands = append(ands, sql.GT(s.C(opts.Fields[i]), values[i])) + ands = append(ands, sql.GT(c, values[i])) } else { - ands = append(ands, sql.LT(s.C(opts.Fields[i]), values[i])) + ands = append(ands, sql.LT(c, values[i])) } or = append(or, sql.And(ands...)) }