diff --git a/codegen/gateway.go b/codegen/gateway.go index aee337c15..85566f080 100644 --- a/codegen/gateway.go +++ b/codegen/gateway.go @@ -30,7 +30,7 @@ import ( "sort" "strings" - yaml "github.com/ghodss/yaml" + "github.com/ghodss/yaml" "github.com/pkg/errors" "go.uber.org/thriftrw/compile" ) @@ -263,6 +263,8 @@ type EndpointSpec struct { ClientID string `yaml:"clientId,omitempty"` // if "httpClient", which client method to call. ClientMethod string `yaml:"clientMethod,omitempty"` + // Config additional configs for this endpoint + Config map[string]interface{} `yaml:"config, omitempty"` // The client for this endpoint if httpClient or tchannelClient ClientSpec *ClientSpec `yaml:"-"` // IsClientlessEndpoint checks if the endpoint is clientless @@ -401,6 +403,12 @@ func NewEndpointSpec( thriftInfo, yamlFile, ) } + var config map[string]interface{} + if _, ok := endpointConfigObj["config"]; !ok { + config = make(map[string]interface{}) + } else { + config = endpointConfigObj["config"].(map[string]interface{}) + } espec := &EndpointSpec{ ModuleSpec: mspec, @@ -420,6 +428,7 @@ func NewEndpointSpec( ClientID: clientID, ClientMethod: clientMethod, DefaultHeaders: h.defaultHeaders, + Config: config, } defaultMidSpecs, err := getOrderedDefaultMiddlewareSpecs( diff --git a/codegen/module.go b/codegen/module.go index c03049c30..dd952683c 100644 --- a/codegen/module.go +++ b/codegen/module.go @@ -83,6 +83,7 @@ type Options struct { EnableCustomInitialisation bool CommitChange bool QPSLevelsEnabled bool + ProxyTemplates *Template } // PostGenHook provides a way to do work after the build is generated, @@ -975,6 +976,7 @@ func (system *ModuleSystem) readInstance( YAMLFileRaw: raw, Config: config.Config, SelectiveBuilding: config.SelectiveBuilding, + ProxyTemplates: options.ProxyTemplates, }, nil } @@ -1787,6 +1789,8 @@ type ModuleInstance struct { mu sync.RWMutex // QPSLevels is map of circuit breaker name to qps level for all circuit breakers QPSLevels map[string]int + // ProxyTemplates is a collection of extra templates apart from zanzibar + ProxyTemplates *Template } func (instance *ModuleInstance) String() string { diff --git a/codegen/module_system.go b/codegen/module_system.go index 6d31760f2..583034590 100644 --- a/codegen/module_system.go +++ b/codegen/module_system.go @@ -520,8 +520,11 @@ func (g *httpClientGenerator) Generate( SidecarRouter: clientSpec.SidecarRouter, } - client, err := g.templates.ExecTemplate( + client, err := ExecuteDefaultOrProxyTemplate( "http_client.tmpl", + g.templates, + instance.ProxyTemplates, + instance.Config, clientMeta, g.packageHelper, ) @@ -774,8 +777,11 @@ func (g *tchannelClientGenerator) Generate( DeputyReqHeader: g.packageHelper.DeputyReqHeader(), } - client, err := g.templates.ExecTemplate( + client, err := ExecuteDefaultOrProxyTemplate( "tchannel_client.tmpl", + g.templates, + instance.ProxyTemplates, + instance.Config, clientMeta, g.packageHelper, ) @@ -813,8 +819,11 @@ func (g *tchannelClientGenerator) Generate( genTestServer, _ := instance.Config["genTestServer"].(bool) if genTestServer { - server, err := g.templates.ExecTemplate( + server, err := ExecuteDefaultOrProxyTemplate( "tchannel_client_test_server.tmpl", + g.templates, + instance.ProxyTemplates, + instance.Config, clientMeta, g.packageHelper, ) @@ -1075,8 +1084,11 @@ func (g *gRPCClientGenerator) Generate( QPSLevels: clientQPSLevels, } - client, err := g.templates.ExecTemplate( + client, err := ExecuteDefaultOrProxyTemplate( "grpc_client.tmpl", + g.templates, + instance.ProxyTemplates, + instance.Config, clientMeta, g.packageHelper, ) @@ -1297,8 +1309,11 @@ func (g *EndpointGenerator) Generate( return endpointMeta[i].Spec.HandleID < endpointMeta[j].Spec.HandleID }) - endpointCollection, err := g.templates.ExecTemplate( + endpointCollection, err := ExecuteDefaultOrProxyTemplate( "endpoint_collection.tmpl", + g.templates, + instance.ProxyTemplates, + instance.Config, &EndpointCollectionMeta{ Instance: instance, EndpointMeta: endpointMeta, @@ -1357,8 +1372,11 @@ func (g *EndpointGenerator) generateEndpointFile(e *EndpointSpec, instance *Modu Instance: instance, Spec: m, } - structs, err := g.templates.ExecTemplate( + structs, err := ExecuteDefaultOrProxyTemplate( "structs.tmpl", + g.templates, + instance.ProxyTemplates, + e.Config, meta, g.packageHelper, ) @@ -1440,9 +1458,11 @@ func (g *EndpointGenerator) generateEndpointFile(e *EndpointSpec, instance *Modu f := func() (interface{}, error) { var endpoint []byte if e.EndpointType == "http" { - endpoint, err = g.templates.ExecTemplate("endpoint.tmpl", meta, g.packageHelper) + endpoint, err = ExecuteDefaultOrProxyTemplate("endpoint.tmpl", g.templates, + instance.ProxyTemplates, e.Config,meta, g.packageHelper) } else if e.EndpointType == "tchannel" { - endpoint, err = g.templates.ExecTemplate("tchannel_endpoint.tmpl", meta, g.packageHelper) + endpoint, err = ExecuteDefaultOrProxyTemplate("tchannel_endpoint.tmpl", g.templates, + instance.ProxyTemplates, e.Config, meta, g.packageHelper) } else { err = errors.Errorf("Endpoint type '%s' is not supported", e.EndpointType) } @@ -1461,7 +1481,7 @@ func (g *EndpointGenerator) generateEndpointFile(e *EndpointSpec, instance *Modu } else { tmpl = "workflow.tmpl" } - workflow, err := g.templates.ExecTemplate(tmpl, meta, g.packageHelper) + workflow, err := ExecuteDefaultOrProxyTemplate(tmpl, g.templates, instance.ProxyTemplates, e.Config, meta, g.packageHelper) if err != nil { return nil, errors.Wrap(err, "Error executing workflow template") } @@ -1542,7 +1562,8 @@ func (g *EndpointGenerator) generateEndpointTestFile( tempName = "endpoint_test_tchannel_client.tmpl" } - endpointTest, err := g.templates.ExecTemplate(tempName, meta, g.packageHelper) + endpointTest, err := ExecuteDefaultOrProxyTemplate(tempName, g.templates, instance.ProxyTemplates, + instance.Config, meta, g.packageHelper) if err != nil { return errors.Wrap(err, "Error executing endpoint test template") } @@ -1586,8 +1607,11 @@ func (generator *GatewayServiceGenerator) Generate(instance *ModuleInstance) (*B runner := parallelize.NewUnboundedRunner(workCount) f := func() (interface{}, error) { - service, err := generator.templates.ExecTemplate( + service, err := ExecuteDefaultOrProxyTemplate( "service.tmpl", + generator.templates, + instance.ProxyTemplates, + instance.Config, instance, generator.packageHelper, ) @@ -1606,8 +1630,11 @@ func (generator *GatewayServiceGenerator) Generate(instance *ModuleInstance) (*B f = func() (interface{}, error) { // generate main.go - main, err := generator.templates.ExecTemplate( + main, err := ExecuteDefaultOrProxyTemplate( "main.tmpl", + generator.templates, + instance.ProxyTemplates, + instance.Config, instance, generator.packageHelper, ) @@ -1626,8 +1653,11 @@ func (generator *GatewayServiceGenerator) Generate(instance *ModuleInstance) (*B f = func() (interface{}, error) { // generate main_test.go - mainTest, err := generator.templates.ExecTemplate( + mainTest, err := ExecuteDefaultOrProxyTemplate( "main_test.tmpl", + generator.templates, + instance.ProxyTemplates, + instance.Config, instance, generator.packageHelper, ) @@ -1760,7 +1790,7 @@ func (g *MiddlewareGenerator) generateMiddlewareFile(instance *ModuleInstance, o templateName = "middleware_tchannel.tmpl" } - bytes, err := g.templates.ExecTemplate(templateName, instance, g.packageHelper) + bytes, err := ExecuteDefaultOrProxyTemplate(templateName, g.templates, instance.ProxyTemplates, instance.Config, instance, g.packageHelper) if err != nil { return err } @@ -1796,8 +1826,11 @@ func GenerateDependencyStruct( if genCustom != "" { instance.PackageInfo.ExportType = instance.Config["customInterface"].(string) } - return template.ExecTemplate( + return ExecuteDefaultOrProxyTemplate( "dependency_struct.tmpl", + template, + instance.ProxyTemplates, + instance.Config, instance, packageHelper, ) @@ -1811,8 +1844,11 @@ func GenerateInitializer( packageHelper *PackageHelper, template *Template, ) ([]byte, error) { - return template.ExecTemplate( + return ExecuteDefaultOrProxyTemplate( "module_initializer.tmpl", + template, + instance.ProxyTemplates, + instance.Config, instance, packageHelper, ) @@ -1857,6 +1893,23 @@ func findMethod( return nil } +// ExecuteDefaultOrProxyTemplate verify and execute a default or proxy template +func ExecuteDefaultOrProxyTemplate( + defaultTemplateName string, + defaultTemplates *Template, + proxyTemplates *Template, + config map[string]interface{}, + tplData interface{}, + packageHelper *PackageHelper, +) (ret []byte, rErr error) { + tmplName, templates := GetDefaultOrProxyTemplate(defaultTemplateName, defaultTemplates, proxyTemplates, config) + return templates.ExecTemplate( + tmplName, + tplData, + packageHelper, + ) +} + type sortByClientName []*ClientSpec func (c sortByClientName) Len() int { diff --git a/codegen/template.go b/codegen/template.go index bdba91b68..8e0fbed1e 100644 --- a/codegen/template.go +++ b/codegen/template.go @@ -32,6 +32,9 @@ import ( templates "github.com/uber/zanzibar/codegen/template_bundle" ) +// ProxyTemplate proxy template representation +const ProxyTemplate = "proxyTemplate" + // AssetProvider provides access to template assets type AssetProvider interface { // AssetNames returns a list of named assets available @@ -191,6 +194,23 @@ func (t *Template) ExecTemplate( return } +// GetDefaultOrProxyTemplate verify and returns overridden template if present +func GetDefaultOrProxyTemplate(defaultTemplateName string, + defaultTemplate *Template, + proxyTemplate *Template, + config map[string]interface{}) (string, *Template) { + // verify if config contains template override + if _, ok := config[ProxyTemplate]; !ok { + return defaultTemplateName, defaultTemplate + } + // verify if override template config contains given template + templateMap := config[ProxyTemplate].(map[string]string) + if _, ok := templateMap[defaultTemplateName]; !ok { + return defaultTemplateName, defaultTemplate + } + return templateMap[defaultTemplateName], proxyTemplate +} + func openFileOrCreate(file string) (*os.File, error) { if _, err := os.Stat(file); os.IsNotExist(err) { if err := os.MkdirAll(filepath.Dir(file), os.ModePerm); err != nil {