From cb412234ad6e44bc054f819c7039bf9c0de67de0 Mon Sep 17 00:00:00 2001 From: zoryamba Date: Mon, 27 Oct 2025 21:45:41 +0200 Subject: [PATCH 1/2] Add WhereHas method --- contracts/database/driver/conditions.go | 13 +-- contracts/database/orm/orm.go | 2 + database/gorm/query.go | 115 +++++++++++++++++++++++- database/gorm/utils.go | 23 ++++- support/database/database.go | 8 ++ 5 files changed, 151 insertions(+), 10 deletions(-) diff --git a/contracts/database/driver/conditions.go b/contracts/database/driver/conditions.go index 22a65d2a0..059946695 100644 --- a/contracts/database/driver/conditions.go +++ b/contracts/database/driver/conditions.go @@ -8,6 +8,8 @@ const ( WhereTypeJsonContains WhereTypeJsonContainsKey WhereTypeJsonLength + // WhereRelation used for where cause with relation subqueries, like WhereHas, OrWhereHas etc. + WhereRelation ) type Conditions struct { @@ -40,9 +42,10 @@ type Join struct { } type Where struct { - Query any - Args []any - Type WhereType - Or bool - IsNot bool + Query any + Args []any + Type WhereType + Relation string + Or bool + IsNot bool } diff --git a/contracts/database/orm/orm.go b/contracts/database/orm/orm.go index eec3eebdb..494f78489 100644 --- a/contracts/database/orm/orm.go +++ b/contracts/database/orm/orm.go @@ -181,6 +181,8 @@ type Query interface { UpdateOrCreate(dest any, attributes any, values any) error // Where add a "where" clause to the query. Where(query any, args ...any) Query + // WhereHas add a relationship count / exists condition to the query with where clauses. + WhereHas(relation string, callback func(query Query) Query, args ...any) Query // WhereBetween adds a "where column between x and y" clause to the query. WhereBetween(column string, x, y any) Query // WhereIn adds a "where column in" clause to the query. diff --git a/database/gorm/query.go b/database/gorm/query.go index f799fae74..5fc4cdbdf 100644 --- a/database/gorm/query.go +++ b/database/gorm/query.go @@ -12,6 +12,7 @@ import ( "github.com/spf13/cast" gormio "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" "github.com/goravel/framework/contracts/config" contractsdatabase "github.com/goravel/framework/contracts/database" @@ -41,6 +42,8 @@ type Query struct { mutex sync.Mutex } +type subqueryCallback = func(contractsorm.Query) contractsorm.Query + func NewQuery( ctx context.Context, config config.Config, @@ -765,7 +768,7 @@ func (r *Query) Scan(dest any) error { return query.instance.Scan(dest).Error } -func (r *Query) Scopes(funcs ...func(contractsorm.Query) contractsorm.Query) contractsorm.Query { +func (r *Query) Scopes(funcs ...subqueryCallback) contractsorm.Query { conditions := r.conditions conditions.scopes = deep.Append(r.conditions.scopes, funcs...) @@ -1026,6 +1029,21 @@ func (r *Query) WhereNotNull(column string) contractsorm.Query { return r.Where(fmt.Sprintf("%s IS NOT NULL", column)) } +func (r *Query) WhereHas(relation string, callback subqueryCallback, args ...any) contractsorm.Query { + subquery := NewQuery(r.ctx, r.config, r.dbConfig, r.instance, r.grammar, r.log, r.modelToObserver, nil) + + if callback != nil { + subquery = callback(subquery).(*Query) + } + + return r.addWhere(contractsdriver.Where{ + Query: subquery, + Args: args, + Relation: relation, + Type: contractsdriver.WhereRelation, + }) +} + func (r *Query) With(query string, args ...any) contractsorm.Query { conditions := r.conditions conditions.with = deep.Append(r.conditions.with, With{ @@ -1295,7 +1313,7 @@ func (r *Query) buildSharedLock(db *gormio.DB) *gormio.DB { return db } -func (r *Query) buildSubquery(sub func(contractsorm.Query) contractsorm.Query) *gormio.DB { +func (r *Query) buildSubquery(sub subqueryCallback) *gormio.DB { db := r.instance.Session(&gormio.Session{NewDB: true, Initialized: true}) queryImpl := NewQuery(r.ctx, r.config, r.dbConfig, db, r.grammar, r.log, r.modelToObserver, nil) query := sub(queryImpl) @@ -1341,9 +1359,11 @@ func (r *Query) buildWhere(db *gormio.DB) *gormio.DB { segments := strings.SplitN(item.Query.(string), " ", 2) segments[0] = r.grammar.CompileJsonLength(segments[0]) item.Query = r.buildWherePlaceholder(strings.Join(segments, " "), item.Args...) + case contractsdriver.WhereRelation: + item.Query, item.Args = r.buildWhereRelation(item) default: switch query := item.Query.(type) { - case func(contractsorm.Query) contractsorm.Query: + case subqueryCallback: item.Query = r.buildSubquery(query) item.Args = nil case string: @@ -1372,6 +1392,93 @@ func (r *Query) buildWhere(db *gormio.DB) *gormio.DB { return db } +func (r *Query) buildWhereRelation(item contractsdriver.Where) (any, []any) { + const ( + gt = ">" + gte = ">=" + lte = ">=" + lt = "<" + ) + + var ( + op string + count int64 + ) + + if argsLen := len(item.Args); argsLen > 0 && argsLen != 1 { + o, ok := item.Args[0].(string) + if !ok { + return r.instance.AddError(errors.New("the first argument should be string, it uses as operator")), []any{} + } + + c, err := cast.ToInt64E(item.Args[1]) + if err != nil { + return r.instance.AddError(errors.New("the second argument should be int64, it uses as count")), []any{} + } + + op = o + count = c + } + + subquery, err := r.relationSubquery(item.Relation, item.Query.(*Query)) + if err != nil { + return r.instance.AddError(err), []any{} + } + + needCountQuery := !((count == 0 && slices.Contains([]string{lt, lte, gt, gte}, op)) || op == "") + + if !needCountQuery { + fmt.Println("exists") + modifiedQueryImpl := subquery.(*Query).buildConditions().instance + return "EXISTS (?)", []any{modifiedQueryImpl} + } + + modifiedQueryImpl := subquery.Select("count(*)").(*Query) + return "(?) " + op + " ?", []any{modifiedQueryImpl.buildConditions().instance, count} +} + +func (r *Query) relationSubquery(relation string, subquery contractsorm.Query) (contractsorm.Query, error) { + mSchema, err := getModelSchema(r.conditions.model, r.instance) + if err != nil { + return nil, fmt.Errorf("faild to get model schema, the model should be set before using this method. %w", err) + } + + fmt.Println(relation) + rel, ok := mSchema.Relationships.Relations[relation] + if !ok { + return nil, fmt.Errorf("relation not found. %s", relation) + } + relModel := getZeroValueFromReflectType(rel.Field.FieldType) + + subquery = subquery.Model(relModel) + + fmt.Printf("%+v\n", subquery) + + fk := rel.References[0].ForeignKey.DBName + ft := rel.FieldSchema.Table + table := mSchema.Table + + switch rel.Type { + case schema.BelongsTo: + pk := rel.FieldSchema.PrioritizedPrimaryField.DBName + subquery = subquery.Where(database.QuoteConcat(ft, pk) + " = " + database.QuoteConcat(table, fk)) + case schema.HasOne, schema.HasMany: + pk := mSchema.PrioritizedPrimaryField.DBName + subquery = subquery.Where(database.QuoteConcat(ft, fk) + " = " + database.QuoteConcat(table, pk)) + case schema.Many2Many: + joinTable := rel.JoinTable.Table + pk := mSchema.PrioritizedPrimaryField.DBName + subquery = subquery. + Join("inner join " + + database.Quote(joinTable) + + " on " + + database.QuoteConcat(mSchema.Table, pk) + + " = " + database.QuoteConcat(joinTable, fk)) + } + + return subquery, nil +} + func (r *Query) buildWherePlaceholder(query string, args ...any) string { // if query does not contain a placeholder,it might be incorrectly quoted or treated as an expression // to avoid errors, append a manual placeholder @@ -1393,7 +1500,7 @@ func (r *Query) buildWith(db *gormio.DB) *gormio.DB { for _, item := range r.conditions.with { isSet := false if len(item.args) == 1 { - if arg, ok := item.args[0].(func(contractsorm.Query) contractsorm.Query); ok { + if arg, ok := item.args[0].(subqueryCallback); ok { newArgs := []any{ func(tx *gormio.DB) *gormio.DB { queryImpl := NewQuery(r.ctx, r.config, r.dbConfig, tx, r.grammar, r.log, r.modelToObserver, nil) diff --git a/database/gorm/utils.go b/database/gorm/utils.go index fb3cf9448..9950d536a 100644 --- a/database/gorm/utils.go +++ b/database/gorm/utils.go @@ -1,6 +1,11 @@ package gorm -import "reflect" +import ( + "reflect" + + gormio "gorm.io/gorm" + "gorm.io/gorm/schema" +) func copyStruct(dest any) reflect.Value { t := reflect.TypeOf(dest) @@ -18,3 +23,19 @@ func copyStruct(dest any) reflect.Value { return v.Convert(copyDestStruct) } + +func getZeroValueFromReflectType(t reflect.Type) any { + if t.Kind() == reflect.Pointer { + return reflect.New(t.Elem()).Interface() + } + return reflect.New(t.Elem()).Interface() +} + +func getModelSchema(model any, db *gormio.DB) (*schema.Schema, error) { + stmt := gormio.Statement{DB: db} + err := stmt.Parse(model) + if err != nil { + return nil, err + } + return stmt.Schema, nil +} diff --git a/support/database/database.go b/support/database/database.go index 1a44fdcae..8795fee43 100644 --- a/support/database/database.go +++ b/support/database/database.go @@ -43,3 +43,11 @@ func GetIDByReflect(t reflect.Type, v reflect.Value) any { return nil } + +func Quote(name string) string { + return "`" + name + "`" +} + +func QuoteConcat(table string, col string) string { + return Quote(table) + "." + Quote(col) +} From 48e574d684070572c18a6e81e2088e8c65bcb41f Mon Sep 17 00:00:00 2001 From: zoryamba Date: Tue, 18 Nov 2025 01:32:48 +0200 Subject: [PATCH 2/2] Add WhereHas tests --- contracts/database/orm/orm.go | 4 +-- database/gorm/operator.go | 33 ++++++++++++++++++ database/gorm/query.go | 13 ++------ database/gorm/query_test.go | 63 +++++++++++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 12 deletions(-) create mode 100644 database/gorm/operator.go diff --git a/contracts/database/orm/orm.go b/contracts/database/orm/orm.go index 494f78489..e1406cc64 100644 --- a/contracts/database/orm/orm.go +++ b/contracts/database/orm/orm.go @@ -181,10 +181,10 @@ type Query interface { UpdateOrCreate(dest any, attributes any, values any) error // Where add a "where" clause to the query. Where(query any, args ...any) Query - // WhereHas add a relationship count / exists condition to the query with where clauses. - WhereHas(relation string, callback func(query Query) Query, args ...any) Query // WhereBetween adds a "where column between x and y" clause to the query. WhereBetween(column string, x, y any) Query + // WhereHas add a relationship count / exists condition to the query with where clauses. + WhereHas(relation string, callback func(query Query) Query, args ...any) Query // WhereIn adds a "where column in" clause to the query. WhereIn(column string, values []any) Query // WhereJsonContains add a "where JSON contains" clause to the query. diff --git a/database/gorm/operator.go b/database/gorm/operator.go new file mode 100644 index 000000000..ce1213559 --- /dev/null +++ b/database/gorm/operator.go @@ -0,0 +1,33 @@ +package gorm + +type Operator = string + +const ( + gt Operator = ">" + gte Operator = ">=" + eq Operator = "=" + lte Operator = "<=" + lt Operator = "<" +) + +func isAnyOperator(s any) (Operator, bool) { + o, ok := s.(string) + if !ok { + return "", false + } + + if isOperator(o) { + return o, true + } + + return "", false +} + +func isOperator(s string) bool { + switch s { + case gt, gte, eq, lte, lt: + return true + default: + return false + } +} diff --git a/database/gorm/query.go b/database/gorm/query.go index 5fc4cdbdf..dbdb09fb4 100644 --- a/database/gorm/query.go +++ b/database/gorm/query.go @@ -1393,20 +1393,13 @@ func (r *Query) buildWhere(db *gormio.DB) *gormio.DB { } func (r *Query) buildWhereRelation(item contractsdriver.Where) (any, []any) { - const ( - gt = ">" - gte = ">=" - lte = ">=" - lt = "<" - ) - var ( - op string + op Operator count int64 ) if argsLen := len(item.Args); argsLen > 0 && argsLen != 1 { - o, ok := item.Args[0].(string) + o, ok := isAnyOperator(item.Args[0]) if !ok { return r.instance.AddError(errors.New("the first argument should be string, it uses as operator")), []any{} } @@ -1425,7 +1418,7 @@ func (r *Query) buildWhereRelation(item contractsdriver.Where) (any, []any) { return r.instance.AddError(err), []any{} } - needCountQuery := !((count == 0 && slices.Contains([]string{lt, lte, gt, gte}, op)) || op == "") + needCountQuery := !((count == 0 && slices.Contains([]Operator{lt, lte, gt, gte}, op)) || op == "") if !needCountQuery { fmt.Println("exists") diff --git a/database/gorm/query_test.go b/database/gorm/query_test.go index c5b95b178..0b153b114 100644 --- a/database/gorm/query_test.go +++ b/database/gorm/query_test.go @@ -71,6 +71,69 @@ func TestAddWhere(t *testing.T) { }, query1.conditions.where) } +func TestAddWhereHas(t *testing.T) { + type Organization struct { + Model + Name string + } + type Post struct { + Model + Title string + } + type Role struct { + Model + Title string + } + type User struct { + Model + Name string + Posts []*Post + Organization *Organization + Roles []*Role `gorm:"many2many:role_user"` + } + + query := (&Query{}). + Model(&User{}). + WhereHas("Posts", nil). + WhereHas("Organization", func(q contractsorm.Query) contractsorm.Query { return q.Where("name", "John") }). + WhereHas("Roles", nil, ">=", 10) + + assert.Equal(t, &Query{ + conditions: Conditions{ + where: []contractsdriver.Where{ + contractsdriver.Where{ + Query: &Query{queries: make(map[string]*Query)}, + Args: nil, + Relation: "Posts", + Type: contractsdriver.WhereRelation, + }, + contractsdriver.Where{ + Query: &Query{ + queries: make(map[string]*Query), + conditions: Conditions{ + where: []contractsdriver.Where{contractsdriver.Where{ + Query: "name", + Args: []any{"John"}, + }}, + }, + }, + Args: nil, + Relation: "Organization", + Type: contractsdriver.WhereRelation, + }, + contractsdriver.Where{ + Query: &Query{queries: make(map[string]*Query)}, + Args: []any{">=", 10}, + Relation: "Roles", + Type: contractsdriver.WhereRelation, + }, + }, + model: &User{}, + }, + queries: make(map[string]*Query), + }, query) +} + func TestGetObserver(t *testing.T) { query := &Query{ modelToObserver: []contractsorm.ModelToObserver{