Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions entproto/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type Extension struct {
entc.DefaultExtension
protoDir string
skipGenFile bool
goPkg string
}

// WithProtoDir sets the directory where the generated .proto files will be written.
Expand All @@ -74,6 +75,14 @@ func SkipGenFile() ExtensionOption {
}
}

// WithGoPkg sets the Go package to be used in the generated .proto files.
// By default, the Go package is derived from the last part of the proto package.
func WithGoPkg(pkg string) ExtensionOption {
return func(e *Extension) {
e.goPkg = pkg
}
}

// Hooks implements entc.Extension.
func (e *Extension) Hooks() []gen.Hook {
return []gen.Hook{e.hook()}
Expand Down Expand Up @@ -165,7 +174,7 @@ func (e *Extension) generate(g *gen.Graph) error {
return fmt.Errorf("entproto: failed generating generate.go file for %q: %w", protoFilePath, err)
}
toSchema := filepath.Join(toBase, "schema")
contents := protocGenerateGo(fd, toSchema)
contents := e.protocGenerateGo(fd, toSchema)
if err := os.WriteFile(genGoPath, []byte(contents), 0600); err != nil {
return fmt.Errorf("entproto: failed generating generate.go file for %q: %w", protoFilePath, err)
}
Expand All @@ -184,7 +193,7 @@ func fileExists(fpath string) bool {
return true
}

func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string {
func (e *Extension) protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string {
levelsUp := len(strings.Split(fd.GetPackage(), "."))
toProtoBase := ""
for i := 0; i < levelsUp; i++ {
Expand All @@ -202,6 +211,12 @@ func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir string) string {
fd.GetName(),
}
goGen := fmt.Sprintf("//go:generate %s", strings.Join(protocCmd, " "))
goPkgName := extractLastFqnPart(fd.GetPackage())

// Use the provided Go package name if set, otherwise derive from the protobuf package
goPkgName := e.goPkg
if goPkgName == "" {
goPkgName = extractLastFqnPart(fd.GetPackage())
}

return fmt.Sprintf("package %s\n%s\n", goPkgName, goGen)
}
77 changes: 77 additions & 0 deletions entproto/extension_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package entproto

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestWithGoPkg(t *testing.T) {
tests := []struct {
name string
goPkg string
fdPackage string
wantPkgName string
}{
{
name: "Default behavior",
goPkg: "",
fdPackage: "example.service",
wantPkgName: "service",
},
{
name: "Custom package",
goPkg: "custompkg",
fdPackage: "example.service",
wantPkgName: "custompkg",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test directly with the extractLastFqnPart helper function
// since we can't easily mock desc.FileDescriptor
var pkgName string
if tt.goPkg != "" {
pkgName = tt.goPkg
} else {
pkgName = extractLastFqnPart(tt.fdPackage)
}

require.Equal(t, tt.wantPkgName, pkgName,
"Expected package name to be %q, got %q", tt.wantPkgName, pkgName)
})
}
}

// TestExtensionGoPkg tests that the WithGoPkg option properly sets the goPkg field
func TestExtensionGoPkg(t *testing.T) {
customPkg := "myspecialpkg"
ext, err := NewExtension(WithGoPkg(customPkg))
require.NoError(t, err)

// Check that the goPkg field is set
require.Equal(t, customPkg, ext.goPkg,
"Expected goPkg to be %q, got %q", customPkg, ext.goPkg)
}

// TestExtractLastFqnPart tests the package extraction function
func TestExtractLastFqnPart(t *testing.T) {
tests := []struct {
input string
want string
}{
{"example.service", "service"},
{"service", "service"},
{"com.example.api.v1", "v1"},
{"", ""},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := extractLastFqnPart(tt.input)
require.Equal(t, tt.want, got,
"extractLastFqnPart(%q) = %q, want %q", tt.input, got, tt.want)
})
}
}