diff --git a/entproto/extension.go b/entproto/extension.go index c1cd0667a..2e692dd2c 100644 --- a/entproto/extension.go +++ b/entproto/extension.go @@ -58,6 +58,7 @@ type Extension struct { entc.DefaultExtension protoDir string skipGenFile bool + goPkg string } // WithProtoDir sets the directory where the generated .proto files will be written. @@ -74,6 +75,14 @@ func SkipGenFile() ExtensionOption { } } +// WithGoPkg sets the Go package to be used in the generated .proto files. +// By default, the Go package is derived from the last part of the proto package. +func WithGoPkg(pkg string) ExtensionOption { + return func(e *Extension) { + e.goPkg = pkg + } +} + // Hooks implements entc.Extension. func (e *Extension) Hooks() []gen.Hook { return []gen.Hook{e.hook()} @@ -165,7 +174,7 @@ func (e *Extension) generate(g *gen.Graph) error { return fmt.Errorf("entproto: failed generating generate.go file for %q: %w", protoFilePath, err) } toSchema := filepath.Join(toBase, "schema") - contents := protocGenerateGo(fd, toSchema) + contents := e.protocGenerateGo(fd, toSchema) if err := os.WriteFile(genGoPath, []byte(contents), 0600); err != nil { return fmt.Errorf("entproto: failed generating generate.go file for %q: %w", protoFilePath, err) } @@ -184,7 +193,7 @@ func fileExists(fpath string) bool { return true } -func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string { +func (e *Extension) protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string { levelsUp := len(strings.Split(fd.GetPackage(), ".")) toProtoBase := "" for i := 0; i < levelsUp; i++ { @@ -202,6 +211,12 @@ func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string { fd.GetName(), } goGen := fmt.Sprintf("//go:generate %s", strings.Join(protocCmd, " ")) - goPkgName := extractLastFqnPart(fd.GetPackage()) + + // Use the provided Go package name if set, otherwise derive from the protobuf package + goPkgName := e.goPkg + if goPkgName == "" { + goPkgName = extractLastFqnPart(fd.GetPackage()) + } + return fmt.Sprintf("package %s\n%s\n", goPkgName, goGen) } diff --git a/entproto/extension_test.go b/entproto/extension_test.go new file mode 100644 index 000000000..67edf0a45 --- /dev/null +++ b/entproto/extension_test.go @@ -0,0 +1,77 @@ +package entproto + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWithGoPkg(t *testing.T) { + tests := []struct { + name string + goPkg string + fdPackage string + wantPkgName string + }{ + { + name: "Default behavior", + goPkg: "", + fdPackage: "example.service", + wantPkgName: "service", + }, + { + name: "Custom package", + goPkg: "custompkg", + fdPackage: "example.service", + wantPkgName: "custompkg", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test directly with the extractLastFqnPart helper function + // since we can't easily mock desc.FileDescriptor + var pkgName string + if tt.goPkg != "" { + pkgName = tt.goPkg + } else { + pkgName = extractLastFqnPart(tt.fdPackage) + } + + require.Equal(t, tt.wantPkgName, pkgName, + "Expected package name to be %q, got %q", tt.wantPkgName, pkgName) + }) + } +} + +// TestExtensionGoPkg tests that the WithGoPkg option properly sets the goPkg field +func TestExtensionGoPkg(t *testing.T) { + customPkg := "myspecialpkg" + ext, err := NewExtension(WithGoPkg(customPkg)) + require.NoError(t, err) + + // Check that the goPkg field is set + require.Equal(t, customPkg, ext.goPkg, + "Expected goPkg to be %q, got %q", customPkg, ext.goPkg) +} + +// TestExtractLastFqnPart tests the package extraction function +func TestExtractLastFqnPart(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"example.service", "service"}, + {"service", "service"}, + {"com.example.api.v1", "v1"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := extractLastFqnPart(tt.input) + require.Equal(t, tt.want, got, + "extractLastFqnPart(%q) = %q, want %q", tt.input, got, tt.want) + }) + } +} \ No newline at end of file