From 01c76737dd2b924550a4c3199160c673c5f434ed Mon Sep 17 00:00:00 2001 From: Paul Glass Date: Tue, 3 Jun 2025 19:14:57 -0500 Subject: [PATCH 1/7] Add benchmark for name translation --- Makefile | 4 +++ interceptor/access_control_test.go | 2 +- interceptor/namespace_translator_test.go | 34 ++++++++++--------- interceptor/reflection_test.go | 42 ++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 17 deletions(-) create mode 100644 interceptor/reflection_test.go diff --git a/Makefile b/Makefile index 802f0578..77c9a353 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,7 @@ GOLANGCI_LINT ?= $(shell which golangci-lint) # Disable cgo by default. CGO_ENABLED ?= 0 TEST_ARG ?= -race -timeout=5m +BENCH_ARG ?= -benchtime=5000x ALL_SRC := $(shell find . -name "*.go") ALL_SRC += go.mod @@ -31,6 +32,9 @@ lint: @printf $(COLOR) "Running golangci-lint...\n" @$(GOLANGCI_LINT) run +bench: + @go test -run '^$$' -benchmem -bench=. ./... $(BENCH_ARG) + # Mocks clean-mocks: @find . -name '*_mock.go' -delete diff --git a/interceptor/access_control_test.go b/interceptor/access_control_test.go index 5dc86025..1e79b6f5 100644 --- a/interceptor/access_control_test.go +++ b/interceptor/access_control_test.go @@ -175,6 +175,6 @@ func testNamespaceAccessControl(t *testing.T, objCases []objCase) { } func TestNamespaceAccessControl(t *testing.T) { - testNamespaceAccessControl(t, generateNamespaceObjCases(t)) + testNamespaceAccessControl(t, generateNamespaceObjCases()) testNamespaceAccessControl(t, generateNamespaceReplicationMessages()) } diff --git a/interceptor/namespace_translator_test.go b/interceptor/namespace_translator_test.go index 864abe50..f2e475a3 100644 --- a/interceptor/namespace_translator_test.go +++ b/interceptor/namespace_translator_test.go @@ -50,7 +50,7 @@ type ( } ) -func generateNamespaceObjCases(t *testing.T) []objCase { +func generateNamespaceObjCases() []objCase { return []objCase{ { objName: "Namespace field", @@ -174,8 +174,8 @@ func generateNamespaceObjCases(t *testing.T) []objCase { return &adminservice.GetWorkflowExecutionRawHistoryV2Response{ NextPageToken: []byte("some-token"), HistoryBatches: []*common.DataBlob{ - makeHistoryEventsBlob(t, ns), - makeHistoryEventsBlob(t, ns), + makeHistoryEventsBlob(ns), + makeHistoryEventsBlob(ns), }, HistoryNodeIds: []int64{123}, } @@ -245,8 +245,8 @@ func generateNamespaceObjCases(t *testing.T) []objCase { NamespaceId: "some-ns-id", WorkflowId: "some-wf-id", RunId: "some-run-id", - Events: makeHistoryEventsBlob(t, ns), - NewRunEvents: makeHistoryEventsBlob(t, ns), + Events: makeHistoryEventsBlob(ns), + NewRunEvents: makeHistoryEventsBlob(ns), }, }, }, @@ -256,8 +256,8 @@ func generateNamespaceObjCases(t *testing.T) []objCase { NamespaceId: "some-ns-id", WorkflowId: "some-wf-id-2", RunId: "some-run-id-2", - Events: makeHistoryEventsBlob(t, ns), - NewRunEvents: makeHistoryEventsBlob(t, ns), + Events: makeHistoryEventsBlob(ns), + NewRunEvents: makeHistoryEventsBlob(ns), }, }, }, @@ -296,12 +296,12 @@ func generateNamespaceObjCases(t *testing.T) []objCase { WorkflowId: "some-wf-id", RunId: "some-run-id", EventBatches: []*common.DataBlob{ - makeHistoryEventsBlob(t, ns), - makeHistoryEventsBlob(t, ns), + makeHistoryEventsBlob(ns), + makeHistoryEventsBlob(ns), }, NewRunInfo: &replicationspb.NewRunInfo{ RunId: "some-new-run-id", - EventBatch: makeHistoryEventsBlob(t, ns), + EventBatch: makeHistoryEventsBlob(ns), }, }, }, @@ -312,12 +312,12 @@ func generateNamespaceObjCases(t *testing.T) []objCase { VersionedTransitionArtifact: &replicationspb.VersionedTransitionArtifact{ StateAttributes: nil, EventBatches: []*common.DataBlob{ - makeHistoryEventsBlob(t, ns), - makeHistoryEventsBlob(t, ns), + makeHistoryEventsBlob(ns), + makeHistoryEventsBlob(ns), }, NewRunInfo: &replicationspb.NewRunInfo{ RunId: "some-run-id", - EventBatch: makeHistoryEventsBlob(t, ns), + EventBatch: makeHistoryEventsBlob(ns), }, }, NamespaceId: "some-ns-id", @@ -540,7 +540,7 @@ func testTranslateNamespace(t *testing.T, objCases []objCase) { } } -func makeHistoryEventsBlob(t *testing.T, ns string) *common.DataBlob { +func makeHistoryEventsBlob(ns string) *common.DataBlob { evts := []*history.HistoryEvent{ { EventId: 1, @@ -568,12 +568,14 @@ func makeHistoryEventsBlob(t *testing.T, ns string) *common.DataBlob { s := serialization.NewSerializer() blob, err := s.SerializeEvents(evts, enums.ENCODING_TYPE_PROTO3) - require.NoError(t, err) + if err != nil { + panic(err) + } return blob } func TestTranslateNamespaceName(t *testing.T) { - testTranslateNamespace(t, generateNamespaceObjCases(t)) + testTranslateNamespace(t, generateNamespaceObjCases()) } func TestTranslateNamespaceReplicationMessages(t *testing.T) { diff --git a/interceptor/reflection_test.go b/interceptor/reflection_test.go new file mode 100644 index 00000000..0dfbfcba --- /dev/null +++ b/interceptor/reflection_test.go @@ -0,0 +1,42 @@ +package interceptor + +import ( + "testing" +) + +func BenchmarkVisitNamespace(b *testing.B) { + variants := []struct { + testName string + inputNSName string + mapping map[string]string + }{ + { + testName: "name changed", + inputNSName: "orig", + mapping: map[string]string{"orig": "orig.cloud"}, + }, + { + testName: "name unchanged", + inputNSName: "orig", + mapping: map[string]string{"other": "other.cloud"}, + }, + } + cases := generateNamespaceObjCases() + + for _, c := range cases { + b.Run(c.objName, func(b *testing.B) { + for _, variant := range variants { + translator := createNameTranslator(variant.mapping) + b.Run(variant.testName, func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + input := c.makeType(variant.inputNSName) + + b.StartTimer() + visitNamespace(input, translator) + } + }) + } + }) + } +} From 28acc0c9ac65d0e9b8c08cd6fbabd62e01747e3c Mon Sep 17 00:00:00 2001 From: Paul Glass Date: Wed, 4 Jun 2025 11:38:41 -0500 Subject: [PATCH 2/7] Refactor translation interceptor --- config/config.go | 24 ++++ interceptor/namespace_translator.go | 177 ------------------------- interceptor/reflection.go | 4 + interceptor/translation_interceptor.go | 128 ++++++++++++++++++ interceptor/translator.go | 37 ++++++ proxy/proxy.go | 13 +- 6 files changed, 203 insertions(+), 180 deletions(-) delete mode 100644 interceptor/namespace_translator.go create mode 100644 interceptor/translation_interceptor.go create mode 100644 interceptor/translator.go diff --git a/config/config.go b/config/config.go index 4d2d1617..9adf7618 100644 --- a/config/config.go +++ b/config/config.go @@ -274,3 +274,27 @@ func NewMockConfigProvider(config S2SProxyConfig) *MockConfigProvider { func (mc *MockConfigProvider) GetS2SProxyConfig() S2SProxyConfig { return mc.config } + +// ToMaps returns request and response mappings. +func (n NamespaceNameTranslationConfig) ToMaps(inBound bool) (map[string]string, map[string]string) { + reqMap := make(map[string]string) + respMap := make(map[string]string) + if inBound { + // For inbound listener, + // - incoming requests from remote server are modifed to match local server + // - outgoing responses to local server are modified to match remote server + for _, tr := range n.Mappings { + reqMap[tr.RemoteName] = tr.LocalName + respMap[tr.LocalName] = tr.RemoteName + } + } else { + // For outbound listener, + // - incoming requests from local server are modifed to match remote server + // - outgoing responses to remote server are modified to match local server + for _, tr := range n.Mappings { + reqMap[tr.LocalName] = tr.RemoteName + respMap[tr.RemoteName] = tr.LocalName + } + } + return reqMap, respMap +} diff --git a/interceptor/namespace_translator.go b/interceptor/namespace_translator.go deleted file mode 100644 index 27e8e438..00000000 --- a/interceptor/namespace_translator.go +++ /dev/null @@ -1,177 +0,0 @@ -package interceptor - -import ( - "context" - "fmt" - "strings" - - "github.com/temporalio/s2s-proxy/common" - "github.com/temporalio/s2s-proxy/config" - "go.temporal.io/server/common/api" - "go.temporal.io/server/common/log" - "go.temporal.io/server/common/log/tag" - "google.golang.org/grpc" -) - -type ( - NamespaceNameTranslator struct { - logger log.Logger - requestNameMapping map[string]string - responseNameMapping map[string]string - } -) - -func NewNamespaceNameTranslator( - logger log.Logger, - cfg config.ProxyConfig, - isInbound bool, - nameTranslations config.NamespaceNameTranslationConfig, -) *NamespaceNameTranslator { - requestNameMapping := map[string]string{} - responseNameMapping := map[string]string{} - for _, tr := range nameTranslations.Mappings { - if isInbound { - // For inbound listener, - // - incoming requests from remote server are modifed to match local server - // - outgoing responses to local server are modified to match remote server - requestNameMapping[tr.RemoteName] = tr.LocalName - responseNameMapping[tr.LocalName] = tr.RemoteName - } else { - // For outbound listener, - // - incoming requests from local server are modifed to match remote server - // - outgoing responses to remote server are modified to match local server - requestNameMapping[tr.LocalName] = tr.RemoteName - responseNameMapping[tr.RemoteName] = tr.LocalName - } - } - - return &NamespaceNameTranslator{ - logger: logger, - requestNameMapping: requestNameMapping, - responseNameMapping: responseNameMapping, - } -} - -var _ grpc.UnaryServerInterceptor = (*NamespaceNameTranslator)(nil).Intercept -var _ grpc.StreamServerInterceptor = (*NamespaceNameTranslator)(nil).InterceptStream - -func createNameTranslator(mapping map[string]string) matcher { - return func(name string) (string, bool) { - newName, ok := mapping[name] - return newName, ok - } -} - -func (i *NamespaceNameTranslator) Intercept( - ctx context.Context, - req any, - info *grpc.UnaryServerInfo, - handler grpc.UnaryHandler, -) (any, error) { - if common.IsRequestTranslationDisabled(ctx) { - return handler(ctx, req) - } - - if len(i.requestNameMapping) == 0 { - return handler(ctx, req) - } - - if strings.HasPrefix(info.FullMethod, api.WorkflowServicePrefix) || - strings.HasPrefix(info.FullMethod, api.AdminServicePrefix) { - - methodName := api.MethodName(info.FullMethod) - i.logger.Debug("intercepted request", tag.NewStringTag("method", methodName)) - - // Translate namespace name in request. - changed, trErr := visitNamespace(req, createNameTranslator(i.requestNameMapping)) - logTranslateNamespaceResult(i.logger, changed, trErr, methodName+"Request", req) - - resp, err := handler(ctx, req) - - // Translate namespace name in response. - changed, trErr = visitNamespace(resp, createNameTranslator(i.responseNameMapping)) - logTranslateNamespaceResult(i.logger, changed, trErr, methodName+"Response", resp) - return resp, err - } else { - return handler(ctx, req) - } -} - -func (i *NamespaceNameTranslator) InterceptStream( - srv any, - ss grpc.ServerStream, - info *grpc.StreamServerInfo, - handler grpc.StreamHandler, -) error { - i.logger.Debug("InterceptStream", - tag.NewAnyTag("method", info.FullMethod), - tag.NewAnyTag("requestMap", i.requestNameMapping), - tag.NewAnyTag("responseMap", i.responseNameMapping), - ) - err := handler(srv, newStreamTranslator( - ss, - i.logger, - i.requestNameMapping, - i.responseNameMapping, - )) - if err != nil { - i.logger.Error("grpc handler with error: %v", tag.Error(err)) - } - return err -} - -type streamTranslator struct { - grpc.ServerStream - logger log.Logger - requestTranslator matcher - responseTranslator matcher -} - -func (w *streamTranslator) RecvMsg(m any) error { - if common.IsRequestTranslationDisabled(w.Context()) { - return w.ServerStream.RecvMsg(m) - } - w.logger.Debug("Intercept RecvMsg", tag.NewAnyTag("message", m)) - changed, trErr := visitNamespace(m, w.requestTranslator) - logTranslateNamespaceResult(w.logger, changed, trErr, "RecvMsg", m) - return w.ServerStream.RecvMsg(m) -} - -func (w *streamTranslator) SendMsg(m any) error { - if common.IsRequestTranslationDisabled(w.Context()) { - return w.ServerStream.SendMsg(m) - } - w.logger.Debug("Intercept SendMsg", tag.NewStringTag("type", fmt.Sprintf("%T", m)), tag.NewAnyTag("message", m)) - changed, trErr := visitNamespace(m, w.responseTranslator) - logTranslateNamespaceResult(w.logger, changed, trErr, "SendMsg", m) - return w.ServerStream.SendMsg(m) -} - -func newStreamTranslator( - s grpc.ServerStream, - logger log.Logger, - requestMapping map[string]string, - responseMapping map[string]string, -) grpc.ServerStream { - return &streamTranslator{ - ServerStream: s, - logger: logger, - requestTranslator: createNameTranslator(requestMapping), - responseTranslator: createNameTranslator(responseMapping), - } -} - -func logTranslateNamespaceResult(logger log.Logger, changed bool, err error, methodName string, obj any) { - logger = log.With( - logger, - tag.NewStringTag("method", methodName), - tag.NewAnyTag("obj", obj), - ) - if err != nil { - logger.Error("namespace translation error", tag.Error(err)) - } else if changed { - logger.Debug("namespace translation applied") - } else { - logger.Debug("namespace translation not applied") - } -} diff --git a/interceptor/reflection.go b/interceptor/reflection.go index ad5d879c..c562de47 100644 --- a/interceptor/reflection.go +++ b/interceptor/reflection.go @@ -31,6 +31,10 @@ var ( // 2. whether or not the input name matches the defined rule(s). type matcher func(name string) (string, bool) +// visitor visits each field in obj matching the matcher. +// It returns whether anything was matched and any error it encountered. +type visitor func(obj any, match matcher) (bool, error) + // visitNamespace uses reflection to recursively visit all fields // in the given object. When it finds namespace string fields, it invokes // the provided match function. diff --git a/interceptor/translation_interceptor.go b/interceptor/translation_interceptor.go new file mode 100644 index 00000000..55cf0f4b --- /dev/null +++ b/interceptor/translation_interceptor.go @@ -0,0 +1,128 @@ +package interceptor + +import ( + "context" + "fmt" + "strings" + + "go.temporal.io/server/common/api" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "google.golang.org/grpc" +) + +type ( + TranslationInterceptor struct { + logger log.Logger + translators []Translator + } +) + +func NewTranslationInterceptor( + logger log.Logger, + translators []Translator, +) *TranslationInterceptor { + return &TranslationInterceptor{ + logger: logger, + translators: translators, + } +} + +var _ grpc.UnaryServerInterceptor = (*TranslationInterceptor)(nil).Intercept +var _ grpc.StreamServerInterceptor = (*TranslationInterceptor)(nil).InterceptStream + +func (i *TranslationInterceptor) Intercept( + ctx context.Context, + req any, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, +) (any, error) { + if len(i.translators) > 0 && + strings.HasPrefix(info.FullMethod, api.WorkflowServicePrefix) || + strings.HasPrefix(info.FullMethod, api.AdminServicePrefix) { + + methodName := api.MethodName(info.FullMethod) + i.logger.Debug("intercepted request", tag.NewStringTag("method", methodName)) + + for _, tr := range i.translators { + changed, trErr := tr.TranslateRequest(req) + logTranslateResult(i.logger, changed, trErr, methodName+"Request", req) + } + + resp, err := handler(ctx, req) + + for _, tr := range i.translators { + changed, trErr := tr.TranslateResponse(resp) + logTranslateResult(i.logger, changed, trErr, methodName+"Response", resp) + } + + return resp, err + } else { + return handler(ctx, req) + } +} + +func (i *TranslationInterceptor) InterceptStream( + srv any, + ss grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler, +) error { + i.logger.Debug("InterceptStream", tag.NewAnyTag("method", info.FullMethod)) + err := handler(srv, newStreamTranslator(ss, i.logger, i.translators)) + if err != nil { + i.logger.Error("grpc handler with error: %v", tag.Error(err)) + } + return err +} + +type streamTranslator struct { + grpc.ServerStream + logger log.Logger + translators []Translator +} + +func (w *streamTranslator) RecvMsg(m any) error { + w.logger.Debug("Intercept RecvMsg", tag.NewAnyTag("message", m)) + for _, tr := range w.translators { + changed, trErr := tr.TranslateRequest(m) + logTranslateResult(w.logger, changed, trErr, "RecvMsg", m) + } + return w.ServerStream.RecvMsg(m) +} + +func (w *streamTranslator) SendMsg(m any) error { + w.logger.Debug("Intercept SendMsg", tag.NewStringTag("type", fmt.Sprintf("%T", m)), tag.NewAnyTag("message", m)) + for _, tr := range w.translators { + changed, trErr := tr.TranslateResponse(m) + logTranslateResult(w.logger, changed, trErr, "SendMsg", m) + } + return w.ServerStream.SendMsg(m) +} + +func newStreamTranslator( + s grpc.ServerStream, + logger log.Logger, + translators []Translator, +) grpc.ServerStream { + return &streamTranslator{ + ServerStream: s, + logger: logger, + translators: translators, + } +} + +func logTranslateResult(logger log.Logger, changed bool, err error, methodName string, obj any) { + logger = log.With( + logger, + tag.NewStringTag("method", methodName), + tag.NewAnyTag("obj", obj), + ) + if err != nil { + logger.Error("translation error", tag.Error(err)) + } else if changed { + logger.Debug("translation applied") + } else { + logger.Debug("translation not applied") + } +} diff --git a/interceptor/translator.go b/interceptor/translator.go new file mode 100644 index 00000000..6e916c00 --- /dev/null +++ b/interceptor/translator.go @@ -0,0 +1,37 @@ +package interceptor + +type ( + Translator interface { + TranslateRequest(any) (bool, error) + TranslateResponse(any) (bool, error) + } + + translatorImpl struct { + matchReq matcher + matchResp matcher + visitor visitor + } +) + +func NewNamespaceNameTranslator(reqMap, respMap map[string]string) Translator { + return &translatorImpl{ + matchReq: createNameTranslator(reqMap), + matchResp: createNameTranslator(respMap), + visitor: visitNamespace, + } +} + +func (n *translatorImpl) TranslateRequest(req any) (bool, error) { + return n.visitor(req, n.matchReq) +} + +func (n *translatorImpl) TranslateResponse(resp any) (bool, error) { + return n.visitor(resp, n.matchResp) +} + +func createNameTranslator(mapping map[string]string) matcher { + return func(name string) (string, bool) { + newName, ok := mapping[name] + return newName, ok + } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 64f35b82..56ab60cb 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -47,12 +47,19 @@ func makeServerOptions(logger log.Logger, cfg config.ProxyConfig, isInbound bool unaryInterceptors := []grpc.UnaryServerInterceptor{} streamInterceptors := []grpc.StreamServerInterceptor{} + var translators []interceptor.Translator if len(nameTranslations.Mappings) > 0 { // NamespaceNameTranslator needs to be called before namespace access control so that // local name can be used in namespace allowed list. - translator := interceptor.NewNamespaceNameTranslator(logger, cfg, isInbound, nameTranslations) - unaryInterceptors = append(unaryInterceptors, translator.Intercept) - streamInterceptors = append(streamInterceptors, translator.InterceptStream) + translators = append(translators, + interceptor.NewNamespaceNameTranslator(nameTranslations.ToMaps(isInbound)), + ) + } + + if len(translators) > 0 { + tr := interceptor.NewTranslationInterceptor(logger, translators) + unaryInterceptors = append(unaryInterceptors, tr.Intercept) + streamInterceptors = append(streamInterceptors, tr.InterceptStream) } if isInbound && cfg.ACLPolicy != nil { From 51d13c5623483c01e2ee6c56c0458d8f50dedd8b Mon Sep 17 00:00:00 2001 From: Paul Glass Date: Wed, 4 Jun 2025 14:04:20 -0500 Subject: [PATCH 3/7] Support translating cluster name --- config/config.go | 17 ++--- interceptor/reflection.go | 143 +++++++++++++++++++++++++++----------- interceptor/translator.go | 8 +++ proxy/proxy.go | 24 +++++-- 4 files changed, 138 insertions(+), 54 deletions(-) diff --git a/config/config.go b/config/config.go index 9adf7618..e0a828eb 100644 --- a/config/config.go +++ b/config/config.go @@ -104,15 +104,16 @@ type ( } S2SProxyConfig struct { - Inbound *ProxyConfig `yaml:"inbound"` - Outbound *ProxyConfig `yaml:"outbound"` - MuxTransports []MuxTransportConfig `yaml:"mux"` - HealthCheck *HealthCheckConfig `yaml:"healthCheck"` - NamespaceNameTranslation NamespaceNameTranslationConfig `yaml:"namespaceNameTranslation"` - Metrics *MetricsConfig `yaml:"metrics"` + Inbound *ProxyConfig `yaml:"inbound"` + Outbound *ProxyConfig `yaml:"outbound"` + MuxTransports []MuxTransportConfig `yaml:"mux"` + HealthCheck *HealthCheckConfig `yaml:"healthCheck"` + NamespaceNameTranslation NameTranslationConfig `yaml:"namespaceNameTranslation"` + ClusterNameTranslation NameTranslationConfig `yaml:"clusterNameTranslation"` + Metrics *MetricsConfig `yaml:"metrics"` } - NamespaceNameTranslationConfig struct { + NameTranslationConfig struct { Mappings []NameMappingConfig `yaml:"mappings"` } @@ -276,7 +277,7 @@ func (mc *MockConfigProvider) GetS2SProxyConfig() S2SProxyConfig { } // ToMaps returns request and response mappings. -func (n NamespaceNameTranslationConfig) ToMaps(inBound bool) (map[string]string, map[string]string) { +func (n NameTranslationConfig) ToMaps(inBound bool) (map[string]string, map[string]string) { reqMap := make(map[string]string) respMap := make(map[string]string) if inBound { diff --git a/interceptor/reflection.go b/interceptor/reflection.go index c562de47..bb5c0c9a 100644 --- a/interceptor/reflection.go +++ b/interceptor/reflection.go @@ -16,6 +16,7 @@ var ( "WorkflowNamespace": true, // PollActivityTaskQueueResponse "ParentWorkflowNamespace": true, // WorkflowExecutionStartedEventAttributes } + dataBlobFieldNames = map[string]bool{ "Events": true, // HistoryTaskAttributes "NewRunEvents": true, // HistoryTaskAttributes @@ -24,6 +25,13 @@ var ( "EventsBatches": true, // HistoryTaskAttributes "HistoryBatches": true, // GetWorkflowExecutionRawHistoryV2 } + + clusterNameFields = map[string]bool{ + "ClusterName": true, // DescribeCluster, ListClusters, ReplicationTasks, GetNamespace (Clusters) + "SourceCluster": true, // HistoryDLQKey + "TargetCluster": true, // HistoryDLQKey + "ActiveClusterName": true, // GetNamespace + } ) // matcher returns 2 values: @@ -64,47 +72,58 @@ func visitNamespace(obj any, match matcher) (bool, error) { } matched = matched || ok } else if dataBlobFieldNames[fieldType.Name] { - switch evt := vwp.Interface().(type) { - case []*common.DataBlob: - newEvts, changed, err := translateDataBlobs(match, evt...) - if err != nil { - return visit.Stop, err - } - if changed { - if err := visit.Assign(vwp, reflect.ValueOf(newEvts)); err != nil { - return visit.Stop, err - } - } - matched = matched || changed - case *common.DataBlob: - newEvt, changed, err := translateOneDataBlob(match, evt) - if err != nil { - return visit.Stop, err - } - if changed { - if err := visit.Assign(vwp, reflect.ValueOf(newEvt)); err != nil { - return visit.Stop, err - } - } - matched = matched || changed - default: - return visit.Continue, nil + changed, err := visitDataBlobs(vwp, match, visitNamespace) + if err != nil { + return visit.Stop, err } + matched = matched || changed + return visit.Continue, nil } else if namespaceFieldNames[fieldType.Name] { - name, ok := vwp.Interface().(string) - if !ok { - return visit.Continue, nil + changed, err := visitStringField(vwp, match) + if err != nil { + return visit.Stop, err } - newName, ok := match(name) - if !ok { - return visit.Continue, nil + matched = matched || changed + return visit.Continue, nil + } + + return visit.Continue, nil + }) + return matched, err +} + +// visitClusterName uses reflection to recursively visit all fields +// in the given object. When it finds matching string fields, it invokes +// the provided match function. +func visitClusterName(obj any, match matcher) (bool, error) { + var matched bool + + // The visitor function can return Skip, Stop, or Continue to control recursion. + err := visit.Values(obj, func(vwp visit.ValueWithParent) (visit.Action, error) { + // Grab name of this struct field from the parent. + if vwp.Parent == nil || vwp.Parent.Kind() != reflect.Struct { + return visit.Continue, nil + } + fieldType := vwp.Parent.Type().Field(int(vwp.Index.Int())) + if !fieldType.IsExported() { + // Ignore unexported fields, particularly private gRPC message fields. + return visit.Skip, nil + } + + if dataBlobFieldNames[fieldType.Name] { + changed, err := visitDataBlobs(vwp, match, visitClusterName) + if err != nil { + return visit.Stop, err } - if name != newName { - if err := visit.Assign(vwp, reflect.ValueOf(newName)); err != nil { - return visit.Stop, err - } + matched = matched || changed + return visit.Continue, nil + } else if clusterNameFields[fieldType.Name] { + changed, err := visitStringField(vwp, match) + if err != nil { + return visit.Stop, err } - matched = matched || ok + matched = matched || changed + return visit.Continue, nil } return visit.Continue, nil @@ -112,11 +131,55 @@ func visitNamespace(obj any, match matcher) (bool, error) { return matched, err } -func translateOneDataBlob(match matcher, blob *common.DataBlob) (*common.DataBlob, bool, error) { +func visitDataBlobs(vwp visit.ValueWithParent, match matcher, visitor visitor) (bool, error) { + switch evt := vwp.Interface().(type) { + case []*common.DataBlob: + newEvts, changed, err := translateDataBlobs(match, visitor, evt...) + if err != nil { + return changed, err + } + if changed { + if err := visit.Assign(vwp, reflect.ValueOf(newEvts)); err != nil { + return changed, err + } + } + return changed, nil + case *common.DataBlob: + newEvt, changed, err := translateOneDataBlob(match, visitor, evt) + if err != nil { + return changed, err + } + if changed { + if err := visit.Assign(vwp, reflect.ValueOf(newEvt)); err != nil { + return changed, err + } + } + return changed, nil + default: + return false, nil + } +} + +func visitStringField(vwp visit.ValueWithParent, match matcher) (bool, error) { + name, ok := vwp.Interface().(string) + if !ok { + return false, nil + } + newName, ok := match(name) + if !ok || name == newName { + return false, nil + } + if err := visit.Assign(vwp, reflect.ValueOf(newName)); err != nil { + return false, err + } + return true, nil +} + +func translateOneDataBlob(match matcher, visit visitor, blob *common.DataBlob) (*common.DataBlob, bool, error) { if blob == nil || len(blob.Data) == 0 { return blob, false, nil } - blobs, changed, err := translateDataBlobs(match, blob) + blobs, changed, err := translateDataBlobs(match, visit, blob) if err != nil { return nil, false, err } @@ -126,7 +189,7 @@ func translateOneDataBlob(match matcher, blob *common.DataBlob) (*common.DataBlo return blobs[0], changed, err } -func translateDataBlobs(match matcher, blobs ...*common.DataBlob) ([]*common.DataBlob, bool, error) { +func translateDataBlobs(match matcher, visit visitor, blobs ...*common.DataBlob) ([]*common.DataBlob, bool, error) { if len(blobs) == 0 { return blobs, false, nil } @@ -140,7 +203,7 @@ func translateDataBlobs(match matcher, blobs ...*common.DataBlob) ([]*common.Dat return blobs, anyChanged, err } - changed, err := visitNamespace(evt, match) + changed, err := visit(evt, match) if err != nil { return blobs, anyChanged, err } diff --git a/interceptor/translator.go b/interceptor/translator.go index 6e916c00..029d5069 100644 --- a/interceptor/translator.go +++ b/interceptor/translator.go @@ -21,6 +21,14 @@ func NewNamespaceNameTranslator(reqMap, respMap map[string]string) Translator { } } +func NewClusterNameTranslator(reqMap, respMap map[string]string) Translator { + return &translatorImpl{ + matchReq: createNameTranslator(reqMap), + matchResp: createNameTranslator(respMap), + visitor: visitClusterName, + } +} + func (n *translatorImpl) TranslateRequest(req any) (bool, error) { return n.visitor(req, n.matchReq) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 56ab60cb..20f799c6 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -43,20 +43,32 @@ type ( } ) -func makeServerOptions(logger log.Logger, cfg config.ProxyConfig, isInbound bool, nameTranslations config.NamespaceNameTranslationConfig) ([]grpc.ServerOption, error) { +func makeServerOptions( + logger log.Logger, + cfg config.ProxyConfig, + isInbound bool, + namespaceTranslation config.NameTranslationConfig, + clusterTranslation config.NameTranslationConfig, +) ([]grpc.ServerOption, error) { unaryInterceptors := []grpc.UnaryServerInterceptor{} streamInterceptors := []grpc.StreamServerInterceptor{} var translators []interceptor.Translator - if len(nameTranslations.Mappings) > 0 { - // NamespaceNameTranslator needs to be called before namespace access control so that - // local name can be used in namespace allowed list. + if len(namespaceTranslation.Mappings) > 0 { + translators = append(translators, + interceptor.NewNamespaceNameTranslator(namespaceTranslation.ToMaps(isInbound)), + ) + } + + if len(clusterTranslation.Mappings) > 0 { translators = append(translators, - interceptor.NewNamespaceNameTranslator(nameTranslations.ToMaps(isInbound)), + interceptor.NewClusterNameTranslator(clusterTranslation.ToMaps(isInbound)), ) } if len(translators) > 0 { + // Translation needs to be called before namespace access control so that + // local name can be used in namespace allowed list. tr := interceptor.NewTranslationInterceptor(logger, translators) unaryInterceptors = append(unaryInterceptors, tr.Intercept) streamInterceptors = append(streamInterceptors, tr.InterceptStream) @@ -92,7 +104,7 @@ func (ps *ProxyServer) startServer( opts := ps.opts logger := ps.logger - serverOpts, err := makeServerOptions(logger, cfg, opts.IsInbound, opts.Config.NamespaceNameTranslation) + serverOpts, err := makeServerOptions(logger, cfg, opts.IsInbound, opts.Config.NamespaceNameTranslation, opts.Config.ClusterNameTranslation) if err != nil { return err } From 8d4490639cbf13e3115d505cef3879bf6359d812 Mon Sep 17 00:00:00 2001 From: Paul Glass Date: Wed, 4 Jun 2025 16:47:16 -0500 Subject: [PATCH 4/7] Add pprof port flag --- cmd/proxy/main.go | 26 ++++++++++++++++++++++---- config/config.go | 1 + 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 31633211..3f65298f 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "net/http" _ "net/http/pprof" "os" @@ -55,6 +56,12 @@ func buildCLIOptions() *cli.App { Usage: "Set log level(debug, info, warn, error). Default level is info", Required: false, }, + &cli.IntFlag{ + Name: config.PProfPortFlag, + Usage: "Port for the pprof HTTP server. Set to -1 to disable the pprof server.", + Required: false, + DefaultText: "6060", + }, }, Action: startProxy, }, @@ -63,19 +70,30 @@ func buildCLIOptions() *cli.App { return app } -func startProfile() { +func startProfile(c *cli.Context) { + port := 6060 + if c.IsSet(config.PProfPortFlag) { // Allow for port=0 to select a random port. + port = c.Int(config.PProfPortFlag) + } + if port < 0 { + return // pprof server disabled + } + if port > 65535 { + panic(fmt.Sprintf("invalid pprof port number %d", port)) + } + + address := fmt.Sprintf("localhost:%d", port) go func() { - if err := http.ListenAndServe("localhost:6060", nil); err != nil { + if err := http.ListenAndServe(address, nil); err != nil { panic(err) } }() - } func startProxy(c *cli.Context) error { var proxyParams ProxyParams - startProfile() + startProfile(c) var logCfg log.Config if logLevel := c.String(config.LogLevelFlag); len(logLevel) != 0 { diff --git a/config/config.go b/config/config.go index e0a828eb..510ee07e 100644 --- a/config/config.go +++ b/config/config.go @@ -14,6 +14,7 @@ import ( const ( ConfigPathFlag = "config" LogLevelFlag = "level" + PProfPortFlag = "pprof-port" ) type TransportType string From 294e106f42795951c70e1dbf52e2059c02214989 Mon Sep 17 00:00:00 2001 From: Paul Glass Date: Mon, 9 Jun 2025 10:47:24 -0500 Subject: [PATCH 5/7] fix, use "matched" instead of "changed" bc it is used for ACLs" --- interceptor/reflection.go | 52 +++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/interceptor/reflection.go b/interceptor/reflection.go index bb5c0c9a..9395a823 100644 --- a/interceptor/reflection.go +++ b/interceptor/reflection.go @@ -134,27 +134,27 @@ func visitClusterName(obj any, match matcher) (bool, error) { func visitDataBlobs(vwp visit.ValueWithParent, match matcher, visitor visitor) (bool, error) { switch evt := vwp.Interface().(type) { case []*common.DataBlob: - newEvts, changed, err := translateDataBlobs(match, visitor, evt...) + newEvts, matched, err := translateDataBlobs(match, visitor, evt...) if err != nil { - return changed, err + return matched, err } - if changed { + if matched { if err := visit.Assign(vwp, reflect.ValueOf(newEvts)); err != nil { - return changed, err + return matched, err } } - return changed, nil + return matched, nil case *common.DataBlob: - newEvt, changed, err := translateOneDataBlob(match, visitor, evt) + newEvt, matched, err := translateOneDataBlob(match, visitor, evt) if err != nil { - return changed, err + return matched, err } - if changed { + if matched { if err := visit.Assign(vwp, reflect.ValueOf(newEvt)); err != nil { - return changed, err + return matched, err } } - return changed, nil + return matched, nil default: return false, nil } @@ -165,28 +165,28 @@ func visitStringField(vwp visit.ValueWithParent, match matcher) (bool, error) { if !ok { return false, nil } - newName, ok := match(name) - if !ok || name == newName { - return false, nil + newName, matched := match(name) + if !matched || name == newName { + return matched, nil } if err := visit.Assign(vwp, reflect.ValueOf(newName)); err != nil { - return false, err + return matched, err } - return true, nil + return matched, nil } func translateOneDataBlob(match matcher, visit visitor, blob *common.DataBlob) (*common.DataBlob, bool, error) { if blob == nil || len(blob.Data) == 0 { return blob, false, nil } - blobs, changed, err := translateDataBlobs(match, visit, blob) + blobs, matched, err := translateDataBlobs(match, visit, blob) if err != nil { - return nil, false, err + return nil, matched, err } if len(blobs) != 1 { - return nil, false, fmt.Errorf("failed to translate single data blob") + return nil, matched, fmt.Errorf("failed to translate single data blob") } - return blobs[0], changed, err + return blobs[0], matched, err } func translateDataBlobs(match matcher, visit visitor, blobs ...*common.DataBlob) ([]*common.DataBlob, bool, error) { @@ -196,25 +196,25 @@ func translateDataBlobs(match matcher, visit visitor, blobs ...*common.DataBlob) s := serialization.NewSerializer() - var anyChanged bool + var anyMatched bool for i, blob := range blobs { evt, err := s.DeserializeEvents(blob) if err != nil { - return blobs, anyChanged, err + return blobs, anyMatched, err } - changed, err := visit(evt, match) + matched, err := visit(evt, match) if err != nil { - return blobs, anyChanged, err + return blobs, anyMatched, err } - anyChanged = anyChanged || changed + anyMatched = anyMatched || matched newBlob, err := s.SerializeEvents(evt, blob.EncodingType) if err != nil { - return blobs, anyChanged, err + return blobs, anyMatched, err } blobs[i] = newBlob } - return blobs, anyChanged, nil + return blobs, anyMatched, nil } From 5d3807b3d456e0075f2fc3c9b829a57539112c42 Mon Sep 17 00:00:00 2001 From: Paul Glass Date: Mon, 9 Jun 2025 11:09:24 -0500 Subject: [PATCH 6/7] unit test --- interceptor/access_control_test.go | 2 +- interceptor/namespace_translator_test.go | 74 ++++++++++++------------ 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/interceptor/access_control_test.go b/interceptor/access_control_test.go index 1e79b6f5..c3ddcb5c 100644 --- a/interceptor/access_control_test.go +++ b/interceptor/access_control_test.go @@ -160,7 +160,7 @@ func testNamespaceAccessControl(t *testing.T, objCases []objCase) { require.ErrorContains(t, err, c.expError) } else { require.NoError(t, err) - if c.containsNamespace { + if c.containsObjName { require.Equal(t, ts.expAllowed, allowed) } else { require.True(t, allowed) diff --git a/interceptor/namespace_translator_test.go b/interceptor/namespace_translator_test.go index f2e475a3..65674280 100644 --- a/interceptor/namespace_translator_test.go +++ b/interceptor/namespace_translator_test.go @@ -43,32 +43,32 @@ type ( } objCase struct { - objName string - containsNamespace bool - makeType func(ns string) any - expError string + objName string + containsObjName bool + makeType func(name string) any + expError string } ) func generateNamespaceObjCases() []objCase { return []objCase{ { - objName: "Namespace field", - containsNamespace: true, + objName: "Namespace field", + containsObjName: true, makeType: func(ns string) any { return &StructWithNamespaceField{Namespace: ns} }, }, { - objName: "WorkflowNamespace field", - containsNamespace: true, + objName: "WorkflowNamespace field", + containsObjName: true, makeType: func(ns string) any { return &StructWithWorkflowNamespaceField{WorkflowNamespace: ns} }, }, { - objName: "Nested Namespace field", - containsNamespace: true, + objName: "Nested Namespace field", + containsObjName: true, makeType: func(ns string) any { return &StructWithNestedNamespaceField{ Other: "do not change", @@ -79,8 +79,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "list of structs", - containsNamespace: true, + objName: "list of structs", + containsObjName: true, makeType: func(ns string) any { return &StructWithListOfNestedNamespaceField{ Other: "do not change", @@ -93,8 +93,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "list of ptrs", - containsNamespace: true, + objName: "list of ptrs", + containsObjName: true, makeType: func(ns string) any { return &StructWithListOfNestedPtrs{ Other: "do not change", @@ -107,8 +107,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "RespondWorkflowTaskCompletedRequest", - containsNamespace: true, + objName: "RespondWorkflowTaskCompletedRequest", + containsObjName: true, makeType: func(ns string) any { return &workflowservice.RespondWorkflowTaskCompletedRequest{ TaskToken: []byte{}, @@ -138,8 +138,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "PollWorkflowTaskQueueResponse", - containsNamespace: true, + objName: "PollWorkflowTaskQueueResponse", + containsObjName: true, makeType: func(ns string) any { return &workflowservice.PollWorkflowTaskQueueResponse{ TaskToken: []byte{}, @@ -168,8 +168,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "GetWorkflowExecutionRawHistoryV2Response", - containsNamespace: true, + objName: "GetWorkflowExecutionRawHistoryV2Response", + containsObjName: true, makeType: func(ns string) any { return &adminservice.GetWorkflowExecutionRawHistoryV2Response{ NextPageToken: []byte("some-token"), @@ -182,8 +182,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "DescribeNamespaceResponse", - containsNamespace: true, + objName: "DescribeNamespaceResponse", + containsObjName: true, makeType: func(ns string) any { return workflowservice.DescribeNamespaceResponse{ NamespaceInfo: &namespace.NamespaceInfo{ @@ -194,8 +194,8 @@ func generateNamespaceObjCases() []objCase { expError: "", }, { - objName: "UpdateNamespaceResponse", - containsNamespace: true, + objName: "UpdateNamespaceResponse", + containsObjName: true, makeType: func(ns string) any { return workflowservice.UpdateNamespaceResponse{ NamespaceInfo: &namespace.NamespaceInfo{ @@ -217,8 +217,8 @@ func generateNamespaceObjCases() []objCase { expError: "", }, { - objName: "ListNamespacesResponse", - containsNamespace: true, + objName: "ListNamespacesResponse", + containsObjName: true, makeType: func(ns string) any { return &workflowservice.ListNamespacesResponse{ Namespaces: []*workflowservice.DescribeNamespaceResponse{ @@ -232,8 +232,8 @@ func generateNamespaceObjCases() []objCase { expError: "", }, { - objName: "StreamWorkflowReplicationMessagesResponse", - containsNamespace: true, + objName: "StreamWorkflowReplicationMessagesResponse", + containsObjName: true, makeType: func(ns string) any { return &adminservice.StreamWorkflowReplicationMessagesResponse{ Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ @@ -331,8 +331,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "circular pointer", - containsNamespace: true, + objName: "circular pointer", + containsObjName: true, makeType: func(ns string) any { a := &StructWithCircularPointer{ Namespace: ns, @@ -473,14 +473,14 @@ func generateNamespaceReplicationMessages() []objCase { }, }, { - objName: "full type", - makeType: makeFullType, - containsNamespace: true, + objName: "full type", + makeType: makeFullType, + containsObjName: true, }, } } -func testTranslateNamespace(t *testing.T, objCases []objCase) { +func testTranslateObjects(t *testing.T, objCases []objCase) { testcases := []struct { testName string inputNSName string @@ -525,7 +525,7 @@ func testTranslateNamespace(t *testing.T, objCases []objCase) { require.ErrorContains(t, err, c.expError) } else { require.NoError(t, err) - if c.containsNamespace { + if c.containsObjName { require.Equal(t, expOutput, input) require.Equal(t, expChanged, changed) } else { @@ -575,9 +575,9 @@ func makeHistoryEventsBlob(ns string) *common.DataBlob { } func TestTranslateNamespaceName(t *testing.T) { - testTranslateNamespace(t, generateNamespaceObjCases()) + testTranslateObjects(t, generateNamespaceObjCases()) } func TestTranslateNamespaceReplicationMessages(t *testing.T) { - testTranslateNamespace(t, generateNamespaceReplicationMessages()) + testTranslateObjects(t, generateNamespaceReplicationMessages()) } From 3d710778c5095ac2dc25c85edd5f1f215f07ce81 Mon Sep 17 00:00:00 2001 From: Paul Glass Date: Mon, 9 Jun 2025 11:19:51 -0500 Subject: [PATCH 7/7] Revert "Add pprof port flag" This reverts commit 8d4490639cbf13e3115d505cef3879bf6359d812. --- cmd/proxy/main.go | 26 ++++---------------------- config/config.go | 1 - 2 files changed, 4 insertions(+), 23 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 3f65298f..31633211 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "net/http" _ "net/http/pprof" "os" @@ -56,12 +55,6 @@ func buildCLIOptions() *cli.App { Usage: "Set log level(debug, info, warn, error). Default level is info", Required: false, }, - &cli.IntFlag{ - Name: config.PProfPortFlag, - Usage: "Port for the pprof HTTP server. Set to -1 to disable the pprof server.", - Required: false, - DefaultText: "6060", - }, }, Action: startProxy, }, @@ -70,30 +63,19 @@ func buildCLIOptions() *cli.App { return app } -func startProfile(c *cli.Context) { - port := 6060 - if c.IsSet(config.PProfPortFlag) { // Allow for port=0 to select a random port. - port = c.Int(config.PProfPortFlag) - } - if port < 0 { - return // pprof server disabled - } - if port > 65535 { - panic(fmt.Sprintf("invalid pprof port number %d", port)) - } - - address := fmt.Sprintf("localhost:%d", port) +func startProfile() { go func() { - if err := http.ListenAndServe(address, nil); err != nil { + if err := http.ListenAndServe("localhost:6060", nil); err != nil { panic(err) } }() + } func startProxy(c *cli.Context) error { var proxyParams ProxyParams - startProfile(c) + startProfile() var logCfg log.Config if logLevel := c.String(config.LogLevelFlag); len(logLevel) != 0 { diff --git a/config/config.go b/config/config.go index 510ee07e..e0a828eb 100644 --- a/config/config.go +++ b/config/config.go @@ -14,7 +14,6 @@ import ( const ( ConfigPathFlag = "config" LogLevelFlag = "level" - PProfPortFlag = "pprof-port" ) type TransportType string