From 3fbff3a0efaca90445fe0357bbdab40ac0200517 Mon Sep 17 00:00:00 2001 From: Arun Saragadam Date: Wed, 12 Mar 2025 00:08:43 +0000 Subject: [PATCH 1/2] entproto: add option to specify Go package name Adds WithGoPkg extension option that allows customizing the Go package name used in the generated generate.go files. By default, the package name is still derived from the last part of the proto package, but this option enables explicit control when needed. This is useful for cases where the derived package name conflicts with Go naming conventions or when users want to align package names across generated code. --- entproto/extension.go | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/entproto/extension.go b/entproto/extension.go index c1cd0667a..2e692dd2c 100644 --- a/entproto/extension.go +++ b/entproto/extension.go @@ -58,6 +58,7 @@ type Extension struct { entc.DefaultExtension protoDir string skipGenFile bool + goPkg string } // WithProtoDir sets the directory where the generated .proto files will be written. @@ -74,6 +75,14 @@ func SkipGenFile() ExtensionOption { } } +// WithGoPkg sets the Go package to be used in the generated .proto files. +// By default, the Go package is derived from the last part of the proto package. +func WithGoPkg(pkg string) ExtensionOption { + return func(e *Extension) { + e.goPkg = pkg + } +} + // Hooks implements entc.Extension. func (e *Extension) Hooks() []gen.Hook { return []gen.Hook{e.hook()} @@ -165,7 +174,7 @@ func (e *Extension) generate(g *gen.Graph) error { return fmt.Errorf("entproto: failed generating generate.go file for %q: %w", protoFilePath, err) } toSchema := filepath.Join(toBase, "schema") - contents := protocGenerateGo(fd, toSchema) + contents := e.protocGenerateGo(fd, toSchema) if err := os.WriteFile(genGoPath, []byte(contents), 0600); err != nil { return fmt.Errorf("entproto: failed generating generate.go file for %q: %w", protoFilePath, err) } @@ -184,7 +193,7 @@ func fileExists(fpath string) bool { return true } -func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string { +func (e *Extension) protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string { levelsUp := len(strings.Split(fd.GetPackage(), ".")) toProtoBase := "" for i := 0; i < levelsUp; i++ { @@ -202,6 +211,12 @@ func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string { fd.GetName(), } goGen := fmt.Sprintf("//go:generate %s", strings.Join(protocCmd, " ")) - goPkgName := extractLastFqnPart(fd.GetPackage()) + + // Use the provided Go package name if set, otherwise derive from the protobuf package + goPkgName := e.goPkg + if goPkgName == "" { + goPkgName = extractLastFqnPart(fd.GetPackage()) + } + return fmt.Sprintf("package %s\n%s\n", goPkgName, goGen) } From db0126e418afc354c553b74eb0238acb62b29ef8 Mon Sep 17 00:00:00 2001 From: Arun Saragadam Date: Wed, 12 Mar 2025 00:11:12 +0000 Subject: [PATCH 2/2] entproto: add tests for new WithGoPkg option Introduces extension_test.go with test cases that verify the WithGoPkg option correctly sets custom Go package names in generated files. Tests cover both the default behavior (deriving package name from proto package) and the custom package name scenario, ensuring backward compatibility while adding the new feature. --- entproto/extension_test.go | 77 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 entproto/extension_test.go diff --git a/entproto/extension_test.go b/entproto/extension_test.go new file mode 100644 index 000000000..67edf0a45 --- /dev/null +++ b/entproto/extension_test.go @@ -0,0 +1,77 @@ +package entproto + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWithGoPkg(t *testing.T) { + tests := []struct { + name string + goPkg string + fdPackage string + wantPkgName string + }{ + { + name: "Default behavior", + goPkg: "", + fdPackage: "example.service", + wantPkgName: "service", + }, + { + name: "Custom package", + goPkg: "custompkg", + fdPackage: "example.service", + wantPkgName: "custompkg", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test directly with the extractLastFqnPart helper function + // since we can't easily mock desc.FileDescriptor + var pkgName string + if tt.goPkg != "" { + pkgName = tt.goPkg + } else { + pkgName = extractLastFqnPart(tt.fdPackage) + } + + require.Equal(t, tt.wantPkgName, pkgName, + "Expected package name to be %q, got %q", tt.wantPkgName, pkgName) + }) + } +} + +// TestExtensionGoPkg tests that the WithGoPkg option properly sets the goPkg field +func TestExtensionGoPkg(t *testing.T) { + customPkg := "myspecialpkg" + ext, err := NewExtension(WithGoPkg(customPkg)) + require.NoError(t, err) + + // Check that the goPkg field is set + require.Equal(t, customPkg, ext.goPkg, + "Expected goPkg to be %q, got %q", customPkg, ext.goPkg) +} + +// TestExtractLastFqnPart tests the package extraction function +func TestExtractLastFqnPart(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"example.service", "service"}, + {"service", "service"}, + {"com.example.api.v1", "v1"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := extractLastFqnPart(tt.input) + require.Equal(t, tt.want, got, + "extractLastFqnPart(%q) = %q, want %q", tt.input, got, tt.want) + }) + } +} \ No newline at end of file