Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
tclp

.idea/
config.yaml
.idea/
73 changes: 64 additions & 9 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@ import (
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"syscall"
"temporal-sa/temporal-cloud-proxy/auth"
"temporal-sa/temporal-cloud-proxy/crypto"
"temporal-sa/temporal-cloud-proxy/metrics"
"temporal-sa/temporal-cloud-proxy/proxy"
"temporal-sa/temporal-cloud-proxy/utils"
"time"

"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/urfave/cli/v2"
"go.opentelemetry.io/otel/attribute"
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/sdk/client"
"google.golang.org/grpc"
Expand Down Expand Up @@ -61,12 +69,32 @@ func main() {
grpcServer := grpc.NewServer()
workflowservice.RegisterWorkflowServiceServer(grpcServer, handler)

// Initialize metrics
metrics.InitPrometheus()
metricsServer := &http.Server{Addr: ":" + strconv.Itoa(cfg.Metrics.Port)}
http.Handle(metrics.DefaultPrometheusPath, promhttp.Handler())
go func() {
fmt.Printf("Metrics is exposed at %s:%d%s\n", cfg.Server.Host, cfg.Metrics.Port, metrics.DefaultPrometheusPath)
if err := metricsServer.ListenAndServe(); err != http.ErrServerClosed {
log.Printf("metrics server error: %v", err)
}
}()

c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
fmt.Println("\nShutting down gracefully...")
grpcServer.GracefulStop()
os.Exit(0)
}()

lis, err := net.Listen("tcp", cfg.Server.Host+":"+strconv.Itoa(cfg.Server.Port))
if err != nil {
return err
}

fmt.Printf("listening on %s:%d\n", cfg.Server.Host, cfg.Server.Port)
fmt.Printf("Proxy is listening on %s:%d\n", cfg.Server.Host, cfg.Server.Port)

err = grpcServer.Serve(lis)

Expand Down Expand Up @@ -118,15 +146,42 @@ func configureProxy(proxyConns *proxy.Conn, cfg *utils.Config) error {
}
}

metricsHandler := metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{
// Todo: do we need these many attributes?
InitialAttributes: attribute.NewSet(
attribute.String("source", t.Source),
attribute.String("target", t.Target),
attribute.String("namespace", t.Namespace),
attribute.String("auth_type", authType),
attribute.String("encryption_key", t.EncryptionKey),
),
})

// Parse global caching config
var cachingConfig *crypto.CachingConfig
if cfg.Encryption.Caching.MaxCache > 0 || cfg.Encryption.Caching.MaxAge != "" || cfg.Encryption.Caching.MaxUsage > 0 {
cachingConfig = &crypto.CachingConfig{
MaxCache: cfg.Encryption.Caching.MaxCache,
MaxMessagesUsed: cfg.Encryption.Caching.MaxUsage,
}
if cfg.Encryption.Caching.MaxAge != "" {
if duration, err := time.ParseDuration(cfg.Encryption.Caching.MaxAge); err == nil {
cachingConfig.MaxAge = duration
}
}
}

err := proxyConns.AddConn(proxy.AddConnInput{
Source: t.Source,
Target: t.Target,
TLSCertPath: t.TLS.CertFile,
TLSKeyPath: t.TLS.KeyFile,
EncryptionKeyID: t.EncryptionKey,
Namespace: t.Namespace,
AuthManager: authManager,
AuthType: authType,
Source: t.Source,
Target: t.Target,
TLSCertPath: t.TLS.CertFile,
TLSKeyPath: t.TLS.KeyFile,
EncryptionKeyID: t.EncryptionKey,
Namespace: t.Namespace,
AuthManager: authManager,
AuthType: authType,
MetricsHandler: metricsHandler,
CryptoCachingConfig: cachingConfig,
})

if err != nil {
Expand Down
56 changes: 45 additions & 11 deletions codec/encryption_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"time"

"temporal-sa/temporal-cloud-proxy/crypto"
"temporal-sa/temporal-cloud-proxy/metrics"

"github.com/aws/aws-sdk-go/service/kms"

commonpb "go.temporal.io/api/common/v1"
"go.temporal.io/sdk/client"
"go.temporal.io/sdk/converter"
)

Expand All @@ -29,13 +31,29 @@ const (

// Codec implements PayloadCodec using the crypto package's cached material manager.
type Codec struct {
KeyID string
Cipher *crypto.Cipher
CodecContext map[string]string
KeyID string
Cipher *crypto.Cipher
CodecContext map[string]string
MetricsHandler client.MetricsHandler
}

// NewEncryptionCodec creates a new encryption codec with the specified key ID, AWS KMS client, and codec context.
func NewEncryptionCodec(kmsClient *kms.KMS, codecContext map[string]string, encryptionKeyID string) converter.PayloadCodec {
// NewEncryptionCodecWithCaching creates a new encryption codec with configurable caching.
func NewEncryptionCodecWithCaching(
kmsClient *kms.KMS,
codecContext map[string]string,
encryptionKeyID string,
metricsHandler client.MetricsHandler,
cachingConfig *crypto.CachingConfig,
) converter.PayloadCodec {
// Set default caching config if not provided
if cachingConfig == nil {
cachingConfig = &crypto.CachingConfig{
MaxCache: 100,
MaxAge: 5 * time.Minute,
MaxMessagesUsed: 100,
}
}

// Create AWS KMS provider
awsProvider := crypto.NewAWSKMSProvider(kmsClient, crypto.KMSOptions{
KeyID: encryptionKeyID,
Expand All @@ -45,18 +63,18 @@ func NewEncryptionCodec(kmsClient *kms.KMS, codecContext map[string]string, encr
// Create caching materials manager
cachingMM, _ := crypto.NewCachingMaterialsManager(
awsProvider,
100, // maxCache
5*time.Minute, // maxAge
1000, // maxMessagesUsed
*cachingConfig,
metricsHandler,
)

// Create cipher with caching materials manager
cipher := crypto.NewCipher(cachingMM)

return &Codec{
KeyID: encryptionKeyID,
Cipher: cipher,
CodecContext: codecContext,
KeyID: encryptionKeyID,
Cipher: cipher,
CodecContext: codecContext,
MetricsHandler: metricsHandler,
}
}

Expand All @@ -77,10 +95,14 @@ func (e *Codec) createCryptoContext(purpose, encryptionKeyID string, codecContex

// Encode implements converter.PayloadCodec.Encode.
func (e *Codec) Encode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
start := time.Now()
e.MetricsHandler.Counter(metrics.EncryptRequests).Inc(1)

result := make([]*commonpb.Payload, len(payloads))
for i, p := range payloads {
origBytes, err := p.Marshal()
if err != nil {
e.MetricsHandler.Counter(metrics.EncryptErrors).Inc(1)
return payloads, err
}

Expand All @@ -98,6 +120,7 @@ func (e *Codec) Encode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error

ciphertext, encryptedKey, err := e.Cipher.Encrypt(context.Background(), input)
if err != nil {
e.MetricsHandler.Counter(metrics.EncryptErrors).Inc(1)
return payloads, err
}

Expand All @@ -111,11 +134,16 @@ func (e *Codec) Encode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error
}
}

e.MetricsHandler.Counter(metrics.EncryptSuccess).Inc(1)
e.MetricsHandler.Timer(metrics.EncryptLatency).Record(time.Since(start))
return result, nil
}

// Decode implements converter.PayloadCodec.Decode.
func (e *Codec) Decode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
start := time.Now()
e.MetricsHandler.Counter(metrics.DecryptRequests).Inc(1)

result := make([]*commonpb.Payload, len(payloads))
for i, p := range payloads {
// Only if it's encrypted
Expand All @@ -126,6 +154,7 @@ func (e *Codec) Decode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error

keyID, ok := p.Metadata[MetadataEncryptionKeyID]
if !ok {
e.MetricsHandler.Counter(metrics.DecryptErrors).Inc(1)
return payloads, fmt.Errorf("no encryption key id")
}

Expand All @@ -135,6 +164,7 @@ func (e *Codec) Decode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error
// Get the encrypted key from metadata
encryptedKey, ok := p.Metadata[MetadataEncryptedDataKey]
if !ok {
e.MetricsHandler.Counter(metrics.DecryptErrors).Inc(1)
return payloads, fmt.Errorf("no encrypted key in payload")
}

Expand All @@ -147,15 +177,19 @@ func (e *Codec) Decode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error

decrypted, err := e.Cipher.Decrypt(context.Background(), input)
if err != nil {
e.MetricsHandler.Counter(metrics.DecryptErrors).Inc(1)
return payloads, err
}

result[i] = &commonpb.Payload{}
err = result[i].Unmarshal(decrypted)
if err != nil {
e.MetricsHandler.Counter(metrics.DecryptErrors).Inc(1)
return payloads, err
}
}

e.MetricsHandler.Counter(metrics.DecryptSuccess).Inc(1)
e.MetricsHandler.Timer(metrics.DecryptLatency).Record(time.Since(start))
return result, nil
}
19 changes: 0 additions & 19 deletions config.yaml

This file was deleted.

40 changes: 40 additions & 0 deletions config.yaml.sample
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
server:
port: 7233
host: "0.0.0.0"
metrics:
port: 9090
encryption:
caching:
max_cache: 100
max_age: "10m"
max_usage: 100
targets:
- source: "<namespace>.<account>.internal"
target: "<namespace>.<account>.tmprl.cloud:7233"
tls:
cert_file: "/path/to/<namespace>.<account>/tls.crt"
key_file: "/path/to/<namespace>.<account>/tls.key"
encryption_key: "<key>"
namespace: "<namespace>.<account>"
authentication:
type: "spiffe"
config:
trust_domain: "spiffe://example.org/"
endpoint: "unix:///tmp/spire-agent/public/api.sock"
audiences:
- "temporal_cloud_proxy"

- source: "<namespace>.<account>.internal"
target: "<namespace>.<account>.tmprl.cloud:7233"
tls:
cert_file: "/path/to/<namespace>.<account>/tls.crt"
key_file: "/path/to/<namespace>.<account>/tls.key"
encryption_key: "<key>"
namespace: "<namespace>.<account>"
authentication:
type: "spiffe"
config:
trust_domain: "spiffe://example.org/"
endpoint: "unix:///tmp/spire-agent/public/api.sock"
audiences:
- "temporal_cloud_proxy"
27 changes: 18 additions & 9 deletions crypto/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,24 @@ func setupManagers(b testing.TB) (*CachingMaterialsManager, *CachingMaterialsMan
// Create the caching materials manager
cachingMM, err := NewCachingMaterialsManager(
awsProvider,
MaxCacheSize,
KeyTTL,
MaxKeyUsage,
CachingConfig{
MaxCache: MaxCacheSize,
MaxAge: KeyTTL,
MaxMessagesUsed: MaxKeyUsage,
},
nil, // MetricsHandler
)
require.NoError(b, err, "Failed to create caching materials manager")

// Create a non-caching materials manager
noCacheMM, err := NewCachingMaterialsManager(
awsProvider,
1, // minimal cache size
0, // zero TTL forces refresh
1, // single use forces refresh
CachingConfig{
MaxCache: 1, // minimal cache size
MaxAge: 0, // zero TTL forces refresh
MaxMessagesUsed: 1, // single use forces refresh
},
nil, // MetricsHandler
)
require.NoError(b, err, "Failed to create no-cache materials manager")

Expand Down Expand Up @@ -216,9 +222,12 @@ func TestCachingBehavior(t *testing.T) {
// Create the caching materials manager with short TTL for testing
cachingMM, err := NewCachingMaterialsManager(
awsProvider,
MaxCacheSize,
100*time.Millisecond, // Very short TTL for testing
5, // Low usage count for testing
CachingConfig{
MaxCache: MaxCacheSize,
MaxAge: 100 * time.Millisecond, // Very short TTL for testing
MaxMessagesUsed: 5, // Low usage count for testing
},
nil, // MetricsHandler
)
require.NoError(t, err, "Failed to create caching materials manager")

Expand Down
Loading