Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion internal/ast/compiler/disjunctions.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ func (pass *DisjunctionToType) processDisjunction(visitor *Visitor, schema *ast.
processedBranch := branch
processedBranch.Nullable = true

fields = append(fields, ast.NewStructField(ast.TypeName(processedBranch), processedBranch))
structField := ast.NewStructField(ast.TypeName(processedBranch), processedBranch)
fields = append(fields, structField)
}

structType := ast.NewStruct(fields...)
Expand All @@ -130,6 +131,8 @@ func (pass *DisjunctionToType) processDisjunction(visitor *Visitor, schema *ast.
structType.Hints[ast.HintDiscriminatedDisjunctionOfRefs] = disjunction
}

structType.Default = def.Default

newObject := ast.NewObject(schema.Package, newTypeName, structType)
newObject.AddToPassesTrail("DisjunctionToType[created]")

Expand Down
25 changes: 25 additions & 0 deletions internal/ast/compiler/disjunctions_infer_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func (pass *DisjunctionInferMapping) processDisjunction(_ *Visitor, schema *ast.
return def, nil
}

def.Default = pass.parseDefault(def)
return def, nil
}

Expand Down Expand Up @@ -150,6 +151,10 @@ func (pass *DisjunctionInferMapping) buildDiscriminatorMapping(schema *ast.Schem
return nil, fmt.Errorf("could not resolve reference '%s'", branch.AsRef().String())
}

if !referredType.IsStruct() {
continue
}

structType := referredType.AsStruct()

field, found := structType.FieldByName(def.Discriminator)
Expand All @@ -176,3 +181,23 @@ func (pass *DisjunctionInferMapping) buildDiscriminatorMapping(schema *ast.Schem

return mapping, nil
}

func (pass *DisjunctionInferMapping) parseDefault(t ast.Type) any {
if t.Default == nil {
return nil
}

disjunction := t.Disjunction

defs := t.Default.(map[string]interface{})
for _, value := range defs {
if _, ok := value.(string); !ok {
continue
}
if referenceName, ok := disjunction.DiscriminatorMapping[value.(string)]; ok {
return referenceName
}
}

return t.Default
}
5 changes: 4 additions & 1 deletion internal/jennies/golang/rawtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast
(field.Required && field.Type.IsArray()) ||
(field.Required && field.Type.IsMap()) ||
field.Type.IsConcreteScalar() ||
field.Type.IsConstantRef()
field.Type.IsConstantRef() ||
(objectType.Default != nil && field.Name == objectType.Default)
if !needsExplicitDefault {
continue
}
Expand Down Expand Up @@ -295,6 +296,8 @@ func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast
defaultValue = formatScalar(field.Type.Default)

defaultValue = jenny.maybeValueAsPointer(defaultValue, field.Type.Nullable, resolvedFieldType)
} else if objectType.Default != nil && objectType.IsRef() {
defaultValue = fmt.Sprintf("New%s()", field.Name)
} else if field.Type.IsRef() && resolvedFieldType.IsStruct() && field.Type.Default != nil {
defaultValue = jenny.defaultsForStruct(context, *field.Type.Ref, resolvedFieldType, field.Type.Default)
if field.Type.Nullable {
Expand Down
1 change: 1 addition & 0 deletions internal/jennies/typescript/jennies.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ func (language *Language) Jennies(globalConfig languages.Config) *codejen.JennyL
func (language *Language) CompilerPasses() compiler.Passes {
return compiler.Passes{
&compiler.RenameNumericEnumValues{},
&compiler.DisjunctionInferMapping{},
}
}

Expand Down
8 changes: 8 additions & 0 deletions internal/jennies/typescript/rawtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ func (jenny RawTypes) defaultValueForObject(object ast.Object, packageMapper pac
}

return raw(jenny.typeFormatter.enums.formatValue(object, defaultValue))
case ast.KindDisjunction:
if object.Type.Default != nil {
if object.Type.AsDisjunction().Branches[0].IsRef() {
return raw(fmt.Sprintf("default%s()", object.Type.Default))
}
return object.Type.Default
}
fallthrough
default:
return jenny.defaultValueForType(object.Type, packageMapper)
}
Expand Down
22 changes: 1 addition & 21 deletions internal/simplecue/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,27 +548,7 @@ func (g *generator) declareDisjunction(v cue.Value, hints ast.JenniesHints, defa
return g.declareAnonymousEnum(v, defaultValue, hints)
}

_, disjunctionBranchesWithPossibleDefault := v.Expr()
defaultAsCueValue, hasDefault := v.Default()

disjunctionBranches := make([]cue.Value, 0, len(disjunctionBranchesWithPossibleDefault))
for _, branch := range disjunctionBranchesWithPossibleDefault {
if hasDefault && branch.Equals(defaultAsCueValue) {
_, bPath := branch.ReferencePath()
_, dPath := defaultAsCueValue.ReferencePath()

if bPath.String() == dPath.String() {
continue
}
}

disjunctionBranches = append(disjunctionBranches, branch)
}

// not a disjunction anymore
if len(disjunctionBranchesWithPossibleDefault) != len(disjunctionBranches) && len(disjunctionBranches) == 1 {
return g.declareNode(disjunctionBranches[0])
}
_, disjunctionBranches := v.Expr()

// We must be looking at a disjunction then (2)
branches := make([]ast.Type, 0, len(disjunctionBranches))
Expand Down
Loading
Loading