Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions pkg/models/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func (LicenseDB) TableName() string {

// BeforeCreate hook to validate data and log the user who is creating the record
func (l *LicenseDB) BeforeCreate(tx *gorm.DB) (err error) {

username, ok := tx.Statement.Context.Value(ContextKey("user")).(string)
if !ok {
return errors.New("username not found in context")
Expand Down Expand Up @@ -449,6 +450,7 @@ type Obligation struct {
Type *ObligationType `gorm:"foreignKey:ObligationTypeId; references:Id"`
Classification *ObligationClassification `gorm:"foreignKey:ObligationClassificationId ;references:Id"`
Category *string `json:"category" gorm:"default:GENERAL" enums:"DISTRIBUTION,PATENT,INTERNAL,CONTRACTUAL,EXPORT_CONTROL,GENERAL" example:"DISTRIBUTION"`
Shortnames []string `gorm:"-" json:"-"` // Ignore in GORM and JSON responses (Temporary)
}

func (Obligation) TableName() string {
Expand Down Expand Up @@ -532,12 +534,40 @@ func (o *Obligation) BeforeCreate(tx *gorm.DB) (err error) {
return err
}

for i := 0; i < len(o.Licenses); i++ {
for i := 0; i < len(o.Shortnames); i++ {
var license LicenseDB
if err := tx.Where(LicenseDB{Shortname: &o.Shortnames[i]}).First(&license).Error; err != nil {
return fmt.Errorf("license with shortname %s not found", o.Shortnames[i])
}
}
return nil
}
func (o *Obligation) AfterCreate(tx *gorm.DB) (err error) {

if len(o.Shortnames) == 0 {
return nil
}

var licenses []*LicenseDB

for i := 0; i < len(o.Shortnames); i++ {
var license LicenseDB
if err := tx.Where(LicenseDB{Shortname: o.Licenses[i].Shortname}).First(&license).Error; err != nil {
return fmt.Errorf("license with shortname %s not found", *o.Licenses[i].Shortname)
if err := tx.Where(LicenseDB{Shortname: &o.Shortnames[i]}).First(&license).Error; err != nil {
return fmt.Errorf("license with shortname %s not found", o.Shortnames[i])
}
licenses = append(licenses, &license)
}

if len(licenses) == 0 {
return fmt.Errorf("no licenses found for the given shortnames")
}

// insert realtion in obligation_licenses table
for _, license := range licenses {
query := `INSERT INTO obligation_licenses (obligation_id, license_db_id) VALUES ($1, $2)`
if err := tx.Exec(query, o.Id, license.Id).Error; err != nil {
return fmt.Errorf("failed to associate license ID %d: %v", license.Id, err)
}
o.Licenses[i] = &license
}

return nil
Expand Down Expand Up @@ -641,10 +671,8 @@ func (o *Obligation) MarshalJSON() ([]byte, error) {
defaultCategory := "GENERAL"
ob.Category = &defaultCategory
}
ob.Shortnames = o.Shortnames

for i := 0; i < len(o.Licenses); i++ {
ob.Shortnames = append(ob.Shortnames, *o.Licenses[i].Shortname)
}
return json.Marshal(ob)
}

Expand Down Expand Up @@ -681,12 +709,7 @@ func (o *Obligation) UnmarshalJSON(data []byte) error {
}
}

o.Licenses = []*LicenseDB{}
for i := 0; i < len(dto.Shortnames); i++ {
o.Licenses = append(o.Licenses, &LicenseDB{
Shortname: &dto.Shortnames[i],
})
}
o.Shortnames = dto.Shortnames

return nil
}
Expand Down Expand Up @@ -744,7 +767,7 @@ type ObligationPreview struct {
Type string `json:"type" enums:"obligation,restriction,risk,right"`
}

// ObligationResponse represents the response format for obligation data.
// ObligationPreviewResponse represents the response format for obligation data.
type ObligationPreviewResponse struct {
Status int `json:"status" example:"200"`
Data []ObligationPreview `json:"data"`
Expand Down
Loading