Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/generate/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func populateResponseType(name string, r *openapi3.Response) ([]TypeTemplate, []
continue
}

tt, et := populateTypeTemplates(respName, s.Value, "")
tt, et := populateTypeTemplates(respName, s.Value, "", "")
types = append(types, tt...)
enumTypes = append(enumTypes, et...)

Expand Down
224 changes: 172 additions & 52 deletions internal/generate/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,25 @@ type TypeTemplate struct {
Type string
// Fields holds the information for the field
Fields []TypeFields

// DiscriminatorKey is the name of the discriminator key used to determine the type of a complex oneOf field: commonly "type", "kind", etc.
DiscriminatorKey string
// DiscriminatorField
DiscriminatorField string
// DiscriminatorType is the generated type name of the discriminator field, and should be constructed as "{{.OneOfName}}Type": MetricType, ValueArrayType, etc.
DiscriminatorType string
// DiscriminatorMappings maps oneOf enum constant (e.g. DatumTypeBool) to their concrete types (e.g. DatumBool).
DiscriminatorMappings []DiscriminatorMapping

VariantField string
VariantInterface string
}

// DiscriminatorMapping maps a discriminator enum value to its concrete type
type DiscriminatorMapping struct {
EnumConstant string // The enum constant to match in switch (e.g., "DatumTypeBool")
ConcreteType string // The concrete type to unmarshal into (e.g., "DatumBool")
ObjectType string
}

// TypeFields holds the information for each type field
Expand Down Expand Up @@ -256,7 +275,7 @@ func constructTypes(schemas openapi3.Schemas) ([]TypeTemplate, []EnumTemplate) {

// Set name as a valid Go type name
name = strcase.ToCamel(name)
typeTpl, enumTpl := populateTypeTemplates(name, s.Value, "")
typeTpl, enumTpl := populateTypeTemplates(name, s.Value, "", "")
typeCollection = append(typeCollection, typeTpl...)
enumCollection = append(enumCollection, enumTpl...)
}
Expand Down Expand Up @@ -326,7 +345,12 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect

fmt.Fprintf(f, "%s\n", splitDocString(tt.Description))
fmt.Fprintf(f, "type %s %s", tt.Name, tt.Type)
if tt.Fields != nil {

if tt.Type == "interface" {
fmt.Fprintf(f, " {\n")
fmt.Fprintf(f, "\tis%s()\n", tt.Name)
fmt.Fprint(f, "}\n")
} else if tt.Fields != nil {
fmt.Fprint(f, " {\n")
for _, ft := range tt.Fields {
if ft.Description != "" {
Expand All @@ -336,6 +360,51 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
}
fmt.Fprint(f, "}\n")
}

// Write custom UnmarshalJSON method for oneOf types.
if tt.DiscriminatorKey != "" && tt.VariantField != "" {
fmt.Fprintf(f, "func (v *%s) UnmarshalJSON(data []byte) error {\n", tt.Name)

// Check the discriminator to decide which type to unmarshal to.
fmt.Fprintf(f, "\tvar peek struct {\n")
fmt.Fprintf(f, "\t\tDiscriminator %s `json:\"%s\"`\n", tt.DiscriminatorType, tt.DiscriminatorKey)
fmt.Fprintf(f, "\t\tValue json.RawMessage `json:\"%s\"`\n", strings.ToLower(tt.VariantField))
fmt.Fprintf(f, "\t}\n")
fmt.Fprintf(f, "\tif err := json.Unmarshal(data, &peek); err != nil {\n")
fmt.Fprintf(f, "\t\treturn err\n")
fmt.Fprintf(f, "\t}\n")
fmt.Fprintf(f, "\tswitch peek.Discriminator {\n")

// Construct a case for each possible variant.
for _, mapping := range tt.DiscriminatorMappings {
fmt.Fprintf(f, "\tcase %s:\n", mapping.EnumConstant)

// For objects, unmarshal into the corresponding struct. For simple types, unmarshal into a temporary struct, then grab the value from it.
if isSimpleType(mapping.ObjectType) {
fmt.Fprintf(f, "\t\tvar val %s\n", mapping.ConcreteType)
fmt.Fprintf(f, "\t\tif err := json.Unmarshal(peek.Value, &val); err != nil {\n")
fmt.Fprintf(f, "\t\t\treturn err\n")
fmt.Fprintf(f, "\t\t}\n")
fmt.Fprintf(f, "\tv.%s = val\n", tt.VariantField)
} else {
fmt.Fprintf(f, "\t\tvar val %s\n", mapping.ConcreteType)
fmt.Fprintf(f, "\t\tif err := json.Unmarshal(peek.Value, &val); err != nil {\n")
fmt.Fprintf(f, "\t\t\treturn err\n")
fmt.Fprintf(f, "\t\t}\n")
fmt.Fprintf(f, "\tv.%s = val\n", tt.VariantField)
}
}
fmt.Fprintf(f, "\tdefault:\n")
fmt.Fprintf(f, "\t\treturn fmt.Errorf(\"unknown %s discriminator value for %s: %%v\", peek.Discriminator)\n", tt.Name, tt.DiscriminatorKey)
fmt.Fprintf(f, "\t}\n")
fmt.Fprintf(f, "\tv.%s = peek.Discriminator\n", tt.DiscriminatorField)
fmt.Fprintf(f, "\treturn nil\n")
fmt.Fprint(f, "}\n")
}
if tt.VariantInterface != "" {
fmt.Fprintf(f, "\n")
fmt.Fprintf(f, "func (%s) is%s() {}\n", tt.Name, tt.VariantInterface)
}
fmt.Fprint(f, "\n")
}

Expand Down Expand Up @@ -381,7 +450,7 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
// populateTypeTemplates populates the template of a type definition for the given schema.
// The additional parameter is only used as a suffix for the type name.
// This is mostly for oneOf types.
func populateTypeTemplates(name string, s *openapi3.Schema, enumFieldName string) ([]TypeTemplate, []EnumTemplate) {
func populateTypeTemplates(name string, s *openapi3.Schema, enumFieldName string, variantInterface string) ([]TypeTemplate, []EnumTemplate) {
typeName := name

// Type name will change for each enum type
Expand All @@ -402,6 +471,8 @@ func populateTypeTemplates(name string, s *openapi3.Schema, enumFieldName string
s.Type = &openapi3.Types{"string"}
}

typeTpl.VariantInterface = variantInterface

switch ot := getObjectType(s); ot {
case "string_enum":
enums, tt, et := createStringEnum(s, collectEnumStringTypes, name, typeName)
Expand All @@ -411,28 +482,29 @@ func populateTypeTemplates(name string, s *openapi3.Schema, enumFieldName string
case "string", "*bool", "int", "int8", "int16", "int32", "int64", "uint", "uint8",
"uint16", "uint32", "uint64", "uintptr", "float32", "float64":
typeTpl.Description = formatTypeDescription(typeName, s)
typeTpl.Type = ot
typeTpl.Type = strings.TrimPrefix(ot, "*")
typeTpl.Name = typeName
case "array":
typeTpl.Description = formatTypeDescription(typeName, s)
typeTpl.Type = fmt.Sprintf("[]%s", s.Items.Value.Type)
typeTpl.Name = typeName
case "object":
typeTpl = createTypeObject(s, name, typeName, formatTypeDescription(typeName, s))
typeTpl.VariantInterface = variantInterface

// Iterate over the properties and append the types, if we need to.
properties := sortedKeys(s.Properties)
for _, k := range properties {
v := s.Properties[k]
if isLocalEnum(v) {
tt, et := populateTypeTemplates(fmt.Sprintf("%s%s", name, strcase.ToCamel(k)), v.Value, "")
tt, et := populateTypeTemplates(fmt.Sprintf("%s%s", name, strcase.ToCamel(k)), v.Value, "", "")
types = append(types, tt...)
enumTypes = append(enumTypes, et...)
}

// TODO: So far this code is never hit with the current openapi spec
if isLocalObject(v) {
tt, et := populateTypeTemplates(fmt.Sprintf("%s%s", name, strcase.ToCamel(k)), v.Value, "")
tt, et := populateTypeTemplates(fmt.Sprintf("%s%s", name, strcase.ToCamel(k)), v.Value, "", "")
types = append(types, tt...)
enumTypes = append(enumTypes, et...)
}
Expand Down Expand Up @@ -624,6 +696,10 @@ func createStringEnum(s *openapi3.Schema, stringEnums map[string][]string, name,
// Probably not the best approach, but will leave them this way until I come up with
// a more idiomatic solution. Keep an eye out on this one to refine.
func createAllOf(s *openapi3.Schema, stringEnums map[string][]string, name, typeName string) []TypeTemplate {

// if typeName == "VpcFirewallRuleTarget" {
// panic("sad")
// }
typeTpls := make([]TypeTemplate, 0)

// Make sure we don't redeclare the enum type.
Expand All @@ -650,54 +726,49 @@ func createAllOf(s *openapi3.Schema, stringEnums map[string][]string, name, type

func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []EnumTemplate) {
var parsedProperties []string
var properties []string
var genericTypes []string
enumTpls := make([]EnumTemplate, 0)
typeTpls := make([]TypeTemplate, 0)
fields := make([]TypeFields, 0)
for _, v := range s.OneOf {
// Iterate over all the schema components in the spec and write the types.
keys := sortedKeys(v.Value.Properties)

for _, prop := range keys {
p := v.Value.Properties[prop]
// We want to collect all the unique properties to create our global oneOf type.
propertyType := convertToValidGoType(prop, typeName, p)
properties = append(properties, prop+"="+propertyType)
discriminator := ""
propertyToVariants := map[string]map[string]struct{}{}
propertyToObjectTypes := map[string]string{}
for _, v := range s.OneOf {
for propName, prop := range v.Value.Properties {
if len(prop.Value.Enum) == 1 {
discriminator = propName
}
if _, ok := propertyToVariants[propName]; !ok {
propertyToVariants[propName] = map[string]struct{}{}
}
goType := convertToValidGoType(propName, typeName, prop)
propertyToVariants[propName][goType] = struct{}{}
propertyToObjectTypes[propName] = getObjectType(prop.Value)
}
}

// When dealing with oneOf sometimes property types will not be the same, we want to
// catch these to set them as "any" when we generate the type.
typeKeys := []string{}
// First we gather all unique properties
for _, v := range properties {
parts := strings.Split(v, "=")
key := parts[0]
if !slices.Contains(typeKeys, key) {
typeKeys = append(typeKeys, key)
variantField := ""
variantFields := []string{}
variantTypes := []string{}
for propName, variants := range propertyToVariants {
if len(variants) > 1 {
variantFields = append(variantFields, propName)
variantTypes = append(variantTypes, propName)
}
}

// For each of the properties above we gather all possible types
// and gather all of those that are not. We will be setting those
// as a generic type
for _, k := range typeKeys {
values := []string{}
for _, v := range properties {
parts := strings.Split(v, "=")
key := parts[0]
value := parts[1]
if key == k {
values = append(values, value)
}
}

if !allItemsAreSame(values) {
genericTypes = append(genericTypes, k)
}
variantInterface := ""
if len(variantFields) == 1 && len(variantTypes) == 1 {
variantField = strcase.ToCamel(variantFields[0])
// Note: the variant interface is only used to define the tagged enum, and not of interest to users, so we make it private.
variantInterface = fmt.Sprintf("%s%s", strcase.ToLowerCamel(typeName), strcase.ToCamel(variantTypes[0]))
typeTpls = append(typeTpls, TypeTemplate{
Name: variantInterface,
Type: "interface",
})
}

discriminatorMappings := []DiscriminatorMapping{}
discriminatorToDiscriminatorType := map[string]string{}

for _, v := range s.OneOf {
// We want to iterate over the properties of the embedded object
// and find the type that is a string.
Expand All @@ -707,16 +778,33 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
keys := sortedKeys(v.Value.Properties)
for _, prop := range keys {
p := v.Value.Properties[prop]
propertyName := strcase.ToCamel(prop)
// We want to collect all the unique properties to create our global oneOf type.
propertyType := convertToValidGoType(prop, typeName, p)

if propertyType == "string" && len(p.Value.Enum) == 1 {
discriminatorToDiscriminatorType[prop] = typeName + strcase.ToCamel(prop)
discriminatorMappings = append(discriminatorMappings, DiscriminatorMapping{
EnumConstant: fmt.Sprintf("%s%s%s", typeName, propertyName, strcase.ToCamel(p.Value.Enum[0].(string))),
ConcreteType: fmt.Sprintf("%s%s", typeName, strcase.ToCamel(p.Value.Enum[0].(string))),
// ObjectType: propertyToObjectTypes[variantTypes[0]],
})
if len(variantFields) > 0 {
vp := v.Value.Properties[strings.ToLower(variantField)]
if vp != nil {
discriminatorMappings[len(discriminatorMappings)-1].ObjectType = propertyToObjectTypes[variantFields[0]]

}
}
}

// Check if we have an enum in order to use the corresponding type instead of
// "string"
if propertyType == "string" && len(p.Value.Enum) != 0 {
propertyType = typeName + strcase.ToCamel(prop)
}

propertyName := strcase.ToCamel(prop)
propertyName = strcase.ToCamel(prop)

// Avoids duplication for every enum
if !containsMatchFirstWord(parsedProperties, propertyName) {
Expand All @@ -728,8 +816,8 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
}

// We set the type of a field as "any" if every element of the oneOf property isn't the same
if slices.Contains(genericTypes, prop) {
field.Type = "any"
if slices.Contains(variantTypes, prop) {
field.Type = variantInterface
}

// Check if the field is nullable and use omitzero instead of omitempty.
Expand Down Expand Up @@ -765,7 +853,28 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E

// TODO: This is the only place that has an "additional name" at the end
// TODO: This is where the "allOf" is being detected
tt, et := populateTypeTemplates(name, v.Value, enumFieldName)
if len(variantFields) == 1 && v.Value.Properties[variantFields[0]] != nil {
variantField := variantFields[0]
variantType := getObjectType(v.Value.Properties[variantField].Value)
if isSimpleType(variantType) {
// Special case: process the variant field property separately
tt, _ := populateTypeTemplates(name, v.Value.Properties[variantField].Value, enumFieldName, variantInterface)
typeTpls = append(typeTpls, tt...)
parentTT, et := populateTypeTemplates(name, v.Value, enumFieldName, variantInterface)
enumTpls = append(enumTpls, et...)

// Only include parent types with Type suffix
for _, tt := range parentTT {
if strings.HasSuffix(tt.Name, strcase.ToCamel(discriminator)) {
typeTpls = append(typeTpls, tt)
}
}
continue
}
}

// Normal case: process the parent schema
tt, et := populateTypeTemplates(name, v.Value, enumFieldName, variantInterface)
typeTpls = append(typeTpls, tt...)
enumTpls = append(enumTpls, et...)
}
Expand All @@ -782,10 +891,15 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
// Make sure to only create structs if the oneOf is not a replacement for enums on the API spec
if len(fields) > 0 {
typeTpl := TypeTemplate{
Description: formatTypeDescription(typeName, s),
Name: typeName,
Type: "struct",
Fields: fields,
Description: formatTypeDescription(typeName, s),
Name: typeName,
Type: "struct",
Fields: fields,
DiscriminatorKey: discriminator,
DiscriminatorField: strcase.ToCamel(discriminator),
DiscriminatorType: discriminatorToDiscriminatorType[discriminator],
DiscriminatorMappings: discriminatorMappings,
VariantField: variantField,
}
typeTpls = append(typeTpls, typeTpl)
}
Expand Down Expand Up @@ -844,3 +958,9 @@ func formatTypeDescription(name string, s *openapi3.Schema) string {
}
return fmt.Sprintf("// %s is the type definition for a %s.", name, name)
}

func isSimpleType(t string) bool {
simpleTypes := []string{"string", "*bool", "int", "int8", "int16", "int32", "int64", "uint", "uint8",
"uint16", "uint32", "uint64", "uintptr", "float32", "float64"}
return slices.Contains(simpleTypes, t)
}
Loading