Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ OUTPUT_PATH := $(OUTPUT_DIR)/$(BINARY_NAME)

.PHONY: all build clean test test-verbose test-coverage test-race test-short test-clean benchmark test-auth test-crypto test-proxy test-utils

all: build
all: test build

build:
go build -o $(OUTPUT_PATH) $(CMD_DIR)
Expand Down
29 changes: 12 additions & 17 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,30 +110,30 @@ func main() {
func configureProxy(proxyConns *proxy.Conn, cfg *utils.Config) error {
ctx := context.TODO()

for _, t := range cfg.Targets {
for _, w := range cfg.Workloads {
var authManager *auth.AuthManager
var authType string

if t.Authentication != nil {
if w.Authentication != nil {
authManager = auth.NewAuthManager()
authType = t.Authentication.Type
authType = w.Authentication.Type

switch authType {
case "spiffe":
spiffeAuth := &auth.SpiffeAuthenticator{
TrustDomain: t.Authentication.Config["trust_domain"].(string),
Endpoint: t.Authentication.Config["endpoint"].(string),
TrustDomain: w.Authentication.Config["trust_domain"].(string),
Endpoint: w.Authentication.Config["endpoint"].(string),
}

if audiences, ok := t.Authentication.Config["audiences"].([]interface{}); ok {
if audiences, ok := w.Authentication.Config["audiences"].([]interface{}); ok {
for _, a := range audiences {
if audience, ok := a.(string); ok {
spiffeAuth.Audiences = append(spiffeAuth.Audiences, audience)
}
}
}

if err := spiffeAuth.Init(ctx, t.Authentication.Config); err != nil {
if err := spiffeAuth.Init(ctx, w.Authentication.Config); err != nil {
return fmt.Errorf("failed to initialize spiffe authenticator: %w", err)
}

Expand All @@ -149,11 +149,11 @@ func configureProxy(proxyConns *proxy.Conn, cfg *utils.Config) error {
metricsHandler := metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{
// Todo: do we need these many attributes?
InitialAttributes: attribute.NewSet(
attribute.String("proxy-id", t.ProxyId),
attribute.String("target", t.Target),
attribute.String("namespace", t.Namespace),
attribute.String("workload_id", w.WorkloadId),
attribute.String("namespace", w.TemporalCloud.Namespace),
attribute.String("host_port", w.TemporalCloud.HostPort),
attribute.String("auth_type", authType),
attribute.String("encryption_key", t.EncryptionKey),
attribute.String("encryption_key", w.EncryptionKey),
),
})

Expand All @@ -172,12 +172,7 @@ func configureProxy(proxyConns *proxy.Conn, cfg *utils.Config) error {
}

err := proxyConns.AddConn(proxy.AddConnInput{
ProxyId: t.ProxyId,
Target: t.Target,
TLSCertPath: t.TLS.CertFile,
TLSKeyPath: t.TLS.KeyFile,
EncryptionKeyID: t.EncryptionKey,
Namespace: t.Namespace,
Workload: &w,
AuthManager: authManager,
AuthType: authType,
MetricsHandler: metricsHandler,
Expand Down
41 changes: 28 additions & 13 deletions config.yaml.sample
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
server:
port: 7233
host: "0.0.0.0"

metrics:
port: 9090

encryption:
caching:
max_cache: 100
max_age: "10m"
max_usage: 100
targets:
- proxy_id: "<namespace>.<account>.internal"
target: "<namespace>.<account>.tmprl.cloud:7233"
tls:
cert_file: "/path/to/<namespace>.<account>/tls.crt"
key_file: "/path/to/<namespace>.<account>/tls.key"

workloads:
- workload_id: "<id>"
temporal_cloud:
namespace: "<namespace>.<account>"
host_port: "<namespace>.<account>.tmprl.cloud:7233" # endpoint when using mTLS
# host_port: "<region>.<cloud>.api.temporal.io:7233" # endpoint when using API keys
authentication: # only set either tls or api_key, not both
tls:
cert_file: "/path/to/<namespace>.<account>/tls.crt"
key_file: "/path/to/<namespace>.<account>/tls.key"
api_key: # only set either value or env, not both
value: "<api key>"
env: <env_var>
encryption_key: "<key>"
namespace: "<namespace>.<account>"
authentication:
type: "spiffe"
config:
Expand All @@ -24,13 +33,19 @@ targets:
audiences:
- "temporal_cloud_proxy"

- source: "<namespace>.<account>.internal"
target: "<namespace>.<account>.tmprl.cloud:7233"
tls:
cert_file: "/path/to/<namespace>.<account>/tls.crt"
key_file: "/path/to/<namespace>.<account>/tls.key"
- workload_id: "<id>"
temporal_cloud:
namespace: "<namespace>.<account>"
host_port: "<namespace>.<account>.tmprl.cloud:7233" # endpoint when using mTLS
# host_port: "<region>.<cloud>.api.temporal.io:7233" # endpoint when using API keys
authentication: # only set either tls or api_key, not both
tls:
cert_file: "/path/to/<namespace>.<account>/tls.crt"
key_file: "/path/to/<namespace>.<account>/tls.key"
api_key: # only set either value or env, not both
value: "<api key>"
env: <env_var>
encryption_key: "<key>"
namespace: "<namespace>.<account>"
authentication:
type: "spiffe"
config:
Expand Down
124 changes: 87 additions & 37 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,22 @@ import (
"crypto/tls"
"errors"
"fmt"
"go.temporal.io/sdk/converter"
"os"
"sync"
"temporal-sa/temporal-cloud-proxy/codec"
"temporal-sa/temporal-cloud-proxy/crypto"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"

"temporal-sa/temporal-cloud-proxy/auth"
"temporal-sa/temporal-cloud-proxy/metrics"

"go.temporal.io/sdk/converter"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"os"
"sync"
"temporal-sa/temporal-cloud-proxy/auth"
"temporal-sa/temporal-cloud-proxy/codec"
"temporal-sa/temporal-cloud-proxy/crypto"
"temporal-sa/temporal-cloud-proxy/metrics"
"temporal-sa/temporal-cloud-proxy/utils"
)

type Conn struct {
Expand Down Expand Up @@ -59,12 +57,7 @@ func createKMSClient() *kms.KMS {

// AddConnInput contains parameters for adding a new connection
type AddConnInput struct {
ProxyId string
Target string
TLSCertPath string
TLSKeyPath string
EncryptionKeyID string
Namespace string
Workload *utils.WorkloadConfig
AuthManager *auth.AuthManager
AuthType string
MetricsHandler metrics.MetricsHandler
Expand All @@ -73,26 +66,34 @@ type AddConnInput struct {

// AddConn adds a new connection to the proxy
func (mc *Conn) AddConn(input AddConnInput) error {
fmt.Printf("Adding connection id: %s target: %s\n", input.ProxyId, input.Target)
fmt.Printf("Adding connection id: %s namespace: %s hostport: %s\n",
input.Workload.WorkloadId, input.Workload.TemporalCloud.Namespace, input.Workload.TemporalCloud.HostPort)

cert, err := tls.LoadX509KeyPair(input.TLSCertPath, input.TLSKeyPath)
if err != nil {
return err
mc.mu.RLock()
_, exists := mc.namespace[input.Workload.WorkloadId]
mc.mu.RUnlock()
if exists {
return fmt.Errorf("workload-id %s already exists", input.Workload.WorkloadId)
}

if input.Workload.TemporalCloud.Authentication.ApiKey != nil && input.Workload.TemporalCloud.Authentication.TLS != nil {
return fmt.Errorf("%s: cannot have both api key and mtls authentication configured on a single workload",
input.Workload.WorkloadId)
}

//Initialize AWS KMS client
kmsClient := createKMSClient()

codecContext := map[string]string{
"namespace": input.Namespace,
"namespace": input.Workload.TemporalCloud.Namespace,
}

clientInterceptor, err := converter.NewPayloadCodecGRPCClientInterceptor(
converter.PayloadCodecGRPCClientInterceptorOptions{
Codecs: []converter.PayloadCodec{codec.NewEncryptionCodecWithCaching(
kmsClient,
codecContext,
input.EncryptionKeyID,
input.Workload.EncryptionKey,
input.MetricsHandler,
input.CryptoCachingConfig,
)},
Expand All @@ -102,21 +103,68 @@ func (mc *Conn) AddConn(input AddConnInput) error {
return err
}

tlsConfig := tls.Config{}

grpcInterceptors := []grpc.UnaryClientInterceptor{
clientInterceptor,
}

if apiKeyConfig := input.Workload.TemporalCloud.Authentication.ApiKey; apiKeyConfig != nil {
if apiKeyConfig.Value != "" && apiKeyConfig.EnvVar != "" {
// TODO proper logging
fmt.Printf("WARN - multiple values provided for api key, using value. workload-id: %s\n", input.Workload.WorkloadId)
}

apiKey := ""
if apiKeyConfig.Value != "" {
apiKey = apiKeyConfig.Value
} else if apiKeyConfig.EnvVar != "" {
apiKey = os.Getenv(apiKeyConfig.EnvVar)
}

if apiKey == "" {
return fmt.Errorf("%s: no api key provided", input.Workload.WorkloadId)
}

grpcInterceptors = append(grpcInterceptors,
func(ctx context.Context, method string, req any, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
md, ok := metadata.FromIncomingContext(ctx)

if ok {
md = md.Copy()
md.Delete("authorization")
md.Delete("temporal-namespace")

ctx = metadata.NewOutgoingContext(ctx, md)
ctx = metadata.AppendToOutgoingContext(ctx, "temporal-namespace", input.Workload.TemporalCloud.Namespace)
ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+apiKey)
}

return invoker(ctx, method, req, reply, cc, opts...)
})
} else {
cert, err := tls.LoadX509KeyPair(input.Workload.TemporalCloud.Authentication.TLS.CertFile,
input.Workload.TemporalCloud.Authentication.TLS.KeyFile)
if err != nil {
return err
}

tlsConfig.Certificates = []tls.Certificate{cert}
}

conn, err := grpc.NewClient(
input.Target,
input.Workload.TemporalCloud.HostPort,
grpc.WithTransportCredentials(credentials.NewTLS(
&tls.Config{
Certificates: []tls.Certificate{cert},
},
&tlsConfig,
)),
grpc.WithUnaryInterceptor(clientInterceptor),
grpc.WithChainUnaryInterceptor(grpcInterceptors...),
)
if err != nil {
return err
}

mc.mu.Lock()
mc.namespace[input.ProxyId] = NamespaceConn{
mc.namespace[input.Workload.WorkloadId] = NamespaceConn{
conn: conn,
authManager: input.AuthManager,
authType: input.AuthType,
Expand All @@ -137,8 +185,10 @@ func (mc *Conn) CloseAll() error {
if err := namespace.conn.Close(); err != nil {
errs = append(errs, err)
}
if err := namespace.authManager.Close(); err != nil {
errs = append(errs, err)
if namespace.authManager != nil {
if err := namespace.authManager.Close(); err != nil {
errs = append(errs, err)
}
}
}

Expand All @@ -152,21 +202,21 @@ func (mc *Conn) Invoke(ctx context.Context, method string, args interface{}, rep
return status.Errorf(codes.InvalidArgument, "unable to read metadata")
}

proxyId := md.Get("proxy-id")
workloadId := md.Get("workload-id")

if len(proxyId) <= 0 {
return status.Error(codes.InvalidArgument, "metadata missing proxy-id")
if len(workloadId) <= 0 {
return status.Error(codes.InvalidArgument, "metadata missing workload-id")
}
if len(proxyId) != 1 {
return status.Error(codes.InvalidArgument, "metadata contains multiple proxy-id entries")
if len(workloadId) != 1 {
return status.Error(codes.InvalidArgument, "metadata contains multiple workload-id entries")
}

mc.mu.RLock()
namespace, exists := mc.namespace[proxyId[0]]
namespace, exists := mc.namespace[workloadId[0]]
mc.mu.RUnlock()

if !exists {
return status.Errorf(codes.InvalidArgument, "invalid proxy-id: %s", proxyId[0])
return status.Errorf(codes.InvalidArgument, "invalid workload-id: %s", workloadId[0])
}

if namespace.authManager != nil {
Expand Down
Loading