From 446cc07760056dd6e50341c9aea685dcef73e763 Mon Sep 17 00:00:00 2001 From: chayandass Date: Mon, 24 Mar 2025 22:32:38 +0530 Subject: [PATCH] fix(obligation_creation)- add aftercreate hook Association query (#121) Signed-off-by: chayandass --- pkg/models/types.go | 51 ++++++++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/pkg/models/types.go b/pkg/models/types.go index e235dbd2..c40d95d3 100644 --- a/pkg/models/types.go +++ b/pkg/models/types.go @@ -59,6 +59,7 @@ func (LicenseDB) TableName() string { // BeforeCreate hook to validate data and log the user who is creating the record func (l *LicenseDB) BeforeCreate(tx *gorm.DB) (err error) { + username, ok := tx.Statement.Context.Value(ContextKey("user")).(string) if !ok { return errors.New("username not found in context") @@ -449,6 +450,7 @@ type Obligation struct { Type *ObligationType `gorm:"foreignKey:ObligationTypeId; references:Id"` Classification *ObligationClassification `gorm:"foreignKey:ObligationClassificationId ;references:Id"` Category *string `json:"category" gorm:"default:GENERAL" enums:"DISTRIBUTION,PATENT,INTERNAL,CONTRACTUAL,EXPORT_CONTROL,GENERAL" example:"DISTRIBUTION"` + Shortnames []string `gorm:"-" json:"-"` // Ignore in GORM and JSON responses (Temporary) } func (Obligation) TableName() string { @@ -532,12 +534,40 @@ func (o *Obligation) BeforeCreate(tx *gorm.DB) (err error) { return err } - for i := 0; i < len(o.Licenses); i++ { + for i := 0; i < len(o.Shortnames); i++ { + var license LicenseDB + if err := tx.Where(LicenseDB{Shortname: &o.Shortnames[i]}).First(&license).Error; err != nil { + return fmt.Errorf("license with shortname %s not found", o.Shortnames[i]) + } + } + return nil +} +func (o *Obligation) AfterCreate(tx *gorm.DB) (err error) { + + if len(o.Shortnames) == 0 { + return nil + } + + var licenses []*LicenseDB + + for i := 0; i < len(o.Shortnames); i++ { var license LicenseDB - if err := tx.Where(LicenseDB{Shortname: o.Licenses[i].Shortname}).First(&license).Error; err != nil { - return fmt.Errorf("license with shortname %s not found", *o.Licenses[i].Shortname) + if err := tx.Where(LicenseDB{Shortname: &o.Shortnames[i]}).First(&license).Error; err != nil { + return fmt.Errorf("license with shortname %s not found", o.Shortnames[i]) + } + licenses = append(licenses, &license) + } + + if len(licenses) == 0 { + return fmt.Errorf("no licenses found for the given shortnames") + } + + // insert realtion in obligation_licenses table + for _, license := range licenses { + query := `INSERT INTO obligation_licenses (obligation_id, license_db_id) VALUES ($1, $2)` + if err := tx.Exec(query, o.Id, license.Id).Error; err != nil { + return fmt.Errorf("failed to associate license ID %d: %v", license.Id, err) } - o.Licenses[i] = &license } return nil @@ -641,10 +671,8 @@ func (o *Obligation) MarshalJSON() ([]byte, error) { defaultCategory := "GENERAL" ob.Category = &defaultCategory } + ob.Shortnames = o.Shortnames - for i := 0; i < len(o.Licenses); i++ { - ob.Shortnames = append(ob.Shortnames, *o.Licenses[i].Shortname) - } return json.Marshal(ob) } @@ -681,12 +709,7 @@ func (o *Obligation) UnmarshalJSON(data []byte) error { } } - o.Licenses = []*LicenseDB{} - for i := 0; i < len(dto.Shortnames); i++ { - o.Licenses = append(o.Licenses, &LicenseDB{ - Shortname: &dto.Shortnames[i], - }) - } + o.Shortnames = dto.Shortnames return nil } @@ -744,7 +767,7 @@ type ObligationPreview struct { Type string `json:"type" enums:"obligation,restriction,risk,right"` } -// ObligationResponse represents the response format for obligation data. +// ObligationPreviewResponse represents the response format for obligation data. type ObligationPreviewResponse struct { Status int `json:"status" example:"200"` Data []ObligationPreview `json:"data"`