diff --git a/cmd/rekor-server/app/root.go b/cmd/rekor-server/app/root.go index c2d68ade9..8a02510fd 100644 --- a/cmd/rekor-server/app/root.go +++ b/cmd/rekor-server/app/root.go @@ -27,6 +27,7 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/sigstore/rekor/pkg/api" "github.com/sigstore/rekor/pkg/log" + "github.com/sigstore/rekor/pkg/trillianclient" cose "github.com/sigstore/rekor/pkg/types/cose/v0.0.1" intoto001 "github.com/sigstore/rekor/pkg/types/intoto/v0.0.1" intoto002 "github.com/sigstore/rekor/pkg/types/intoto/v0.0.2" @@ -93,6 +94,9 @@ func init() { rootCmd.PersistentFlags().Uint("trillian_log_server.tlog_id", 0, "Trillian tree id") rootCmd.PersistentFlags().String("trillian_log_server.sharding_config", "", "path to config file for inactive shards, in JSON or YAML") rootCmd.PersistentFlags().String("trillian_log_server.grpc_default_service_config", "", "JSON string used to configure gRPC clients for communicating with Trillian") + rootCmd.PersistentFlags().Duration("trillian_log_server.init_latest_root_timeout", trillianclient.DefaultInitLatestRootTimeout, "timeout for fetching the latest root during client initialization") + rootCmd.PersistentFlags().Duration("trillian_log_server.updater_wait_timeout", trillianclient.DefaultUpdaterWaitTimeout, "timeout for STH updater polling wait operations") + rootCmd.PersistentFlags().Bool("trillian_log_server.cache_sth", false, "enable cached STH client with background root updates (experimental)") rootCmd.PersistentFlags().Uint("publish_frequency", 5, "how often to publish a new checkpoint, in minutes") diff --git a/docker-compose.yml b/docker-compose.yml index b53b3fe34..a6622c612 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -116,6 +116,7 @@ services: "--attestation_storage_bucket=file:///var/run/attestations", "--search_index.storage_provider=mysql", "--search_index.mysql.dsn=test:zaphod@tcp(mysql:3306)/test", + # "--trillian_log_server.cache_sth=true", # Uncomment this for production 
logging # "--log_type=prod", ] diff --git a/pkg/api/api.go b/pkg/api/api.go index 53cefa7d0..4b31be766 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -109,7 +109,22 @@ func NewAPI(treeID int64) (*API, error) { inactiveGRPCConfigs[r.TreeID] = *r.GRPCConfig } } - tcm := trillianclient.NewClientManager(inactiveGRPCConfigs, defaultGRPCConfig) + + // Inactive shards are frozen — their trees will never advance. + frozenTreeIDs := make(map[int64]bool) + for _, r := range ranges.GetInactive() { + frozenTreeIDs[r.TreeID] = true + } + + // Read timeout configuration from command line flags/config + clientConfig := trillianclient.Config{ + CacheSTH: viper.GetBool("trillian_log_server.cache_sth"), + InitLatestRootTimeout: viper.GetDuration("trillian_log_server.init_latest_root_timeout"), + UpdaterWaitTimeout: viper.GetDuration("trillian_log_server.updater_wait_timeout"), + FrozenTreeIDs: frozenTreeIDs, + } + + tcm := trillianclient.NewClientManager(inactiveGRPCConfigs, defaultGRPCConfig, clientConfig) roots, err := ranges.CompleteInitialization(ctx, tcm) if err != nil { diff --git a/pkg/sharding/ranges_test.go b/pkg/sharding/ranges_test.go index dfbe6182b..fb1fa9636 100644 --- a/pkg/sharding/ranges_test.go +++ b/pkg/sharding/ranges_test.go @@ -654,19 +654,6 @@ func TestCompleteInitialization_Scenarios(t *testing.T) { SigningSchemeOrKeyPath: keyPath, } - // --- Scenario 1: Multiple Backends --- - s1, close1 := setupMockServer(t, mockCtl) - defer close1() - addr1 := s1.Addr - port1, err := strconv.Atoi(addr1[strings.LastIndex(addr1, ":")+1:]) - require.NoError(t, err) - - s2, close2 := setupMockServer(t, mockCtl) - defer close2() - addr2 := s2.Addr - port2, err := strconv.Atoi(addr2[strings.LastIndex(addr2, ":")+1:]) - require.NoError(t, err) - // --- Scenario 4: Connection Failure --- // Find an unused port for the connection failure test lisClosed, err := net.Listen("tcp", ":0") @@ -683,27 +670,40 @@ func TestCompleteInitialization_Scenarios(t *testing.T) { }{ { name: 
"Scenario 1: Multiple Backends", - setup: func(_ *testing.T, logRanges *LogRanges, tcm **trillianclient.ClientManager) { + setup: func(t *testing.T, logRanges *LogRanges, tcm **trillianclient.ClientManager) { // Setup two inactive shards, each pointing to a different server inactive1, _ := initializeRange(context.Background(), LogRange{TreeID: 101, SigningConfig: activeSC}) inactive2, _ := initializeRange(context.Background(), LogRange{TreeID: 102, SigningConfig: activeSC}) logRanges.inactive = Ranges{inactive1, inactive2} + // Create isolated servers for this scenario + sA, closeA := setupMockServer(t, mockCtl) + t.Cleanup(closeA) + addrA := sA.Addr + portA, err := strconv.Atoi(addrA[strings.LastIndex(addrA, ":")+1:]) + require.NoError(t, err) + + sB, closeB := setupMockServer(t, mockCtl) + t.Cleanup(closeB) + addrB := sB.Addr + portB, err := strconv.Atoi(addrB[strings.LastIndex(addrB, ":")+1:]) + require.NoError(t, err) + // Mock responses from each server root1 := &types.LogRootV1{TreeSize: 42} rootBytes1, _ := root1.MarshalBinary() - s1.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: rootBytes1}}, nil) + sA.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: rootBytes1}}, nil).MinTimes(1) root2 := &types.LogRootV1{TreeSize: 84} rootBytes2, _ := root2.MarshalBinary() - s2.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: rootBytes2}}, nil) + sB.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: rootBytes2}}, nil).MinTimes(1) // Configure client manager to route to the correct servers grpcConfigs := 
map[int64]trillianclient.GRPCConfig{ - 101: {Address: "localhost", Port: uint16(port1)}, - 102: {Address: "localhost", Port: uint16(port2)}, + 101: {Address: "localhost", Port: uint16(portA)}, + 102: {Address: "localhost", Port: uint16(portB)}, } - *tcm = trillianclient.NewClientManager(grpcConfigs, trillianclient.GRPCConfig{}) + *tcm = trillianclient.NewClientManager(grpcConfigs, trillianclient.GRPCConfig{}, trillianclient.DefaultConfig()) }, expectErr: false, postCondition: func(t *testing.T, logRanges *LogRanges, roots map[int64]types.LogRootV1) { @@ -718,17 +718,24 @@ func TestCompleteInitialization_Scenarios(t *testing.T) { }, { name: "Scenario 2: Fallback to Default Backend", - setup: func(_ *testing.T, logRanges *LogRanges, tcm **trillianclient.ClientManager) { + setup: func(t *testing.T, logRanges *LogRanges, tcm **trillianclient.ClientManager) { inactive, _ := initializeRange(context.Background(), LogRange{TreeID: 201, SigningConfig: activeSC}) logRanges.inactive = Ranges{inactive} + // Create a dedicated default backend for this scenario + sDef, closeDef := setupMockServer(t, mockCtl) + t.Cleanup(closeDef) + addr := sDef.Addr + port, err := strconv.Atoi(addr[strings.LastIndex(addr, ":")+1:]) + require.NoError(t, err) + root := &types.LogRootV1{TreeSize: 99} rootBytes, _ := root.MarshalBinary() - s1.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: rootBytes}}, nil) + sDef.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: rootBytes}}, nil).MinTimes(1) // No specific config for tree 201, so it should use the default - defaultConfig := trillianclient.GRPCConfig{Address: "localhost", Port: uint16(port1)} - *tcm = trillianclient.NewClientManager(map[int64]trillianclient.GRPCConfig{}, defaultConfig) + defaultConfig := 
trillianclient.GRPCConfig{Address: "localhost", Port: uint16(port)} + *tcm = trillianclient.NewClientManager(map[int64]trillianclient.GRPCConfig{}, defaultConfig, trillianclient.DefaultConfig()) }, expectErr: false, postCondition: func(t *testing.T, logRanges *LogRanges, roots map[int64]types.LogRootV1) { @@ -742,7 +749,9 @@ func TestCompleteInitialization_Scenarios(t *testing.T) { name: "Scenario 3: No Inactive Shards", setup: func(_ *testing.T, logRanges *LogRanges, tcm **trillianclient.ClientManager) { logRanges.inactive = Ranges{} - *tcm = trillianclient.NewClientManager(nil, trillianclient.GRPCConfig{Address: "localhost", Port: uint16(port1)}) + // No inactive shards means the client manager won't be used. + // Provide a no-op default config to satisfy constructor. + *tcm = trillianclient.NewClientManager(nil, trillianclient.GRPCConfig{Address: "localhost", Port: 0}, trillianclient.DefaultConfig()) }, expectErr: false, postCondition: func(t *testing.T, logRanges *LogRanges, roots map[int64]types.LogRootV1) { @@ -760,23 +769,30 @@ func TestCompleteInitialization_Scenarios(t *testing.T) { grpcConfigs := map[int64]trillianclient.GRPCConfig{ 401: {Address: "localhost", Port: uint16(closedAddr.Port)}, } - *tcm = trillianclient.NewClientManager(grpcConfigs, trillianclient.GRPCConfig{}) + *tcm = trillianclient.NewClientManager(grpcConfigs, trillianclient.GRPCConfig{}, trillianclient.DefaultConfig()) }, expectErr: true, }, { name: "Scenario 5: Trillian API Error", - setup: func(_ *testing.T, logRanges *LogRanges, tcm **trillianclient.ClientManager) { + setup: func(t *testing.T, logRanges *LogRanges, tcm **trillianclient.ClientManager) { inactive, _ := initializeRange(context.Background(), LogRange{TreeID: 501, SigningConfig: activeSC}) logRanges.inactive = Ranges{inactive} + // Create a dedicated backend that returns an error + sErr, closeErr := setupMockServer(t, mockCtl) + t.Cleanup(closeErr) + addr := sErr.Addr + port, err := 
strconv.Atoi(addr[strings.LastIndex(addr, ":")+1:]) + require.NoError(t, err) + // Mock an error from the Trillian server - s1.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(nil, status.Error(codes.NotFound, "tree not found")) + sErr.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(nil, status.Error(codes.NotFound, "tree not found")).MinTimes(1) grpcConfigs := map[int64]trillianclient.GRPCConfig{ - 501: {Address: "localhost", Port: uint16(port1)}, + 501: {Address: "localhost", Port: uint16(port)}, } - *tcm = trillianclient.NewClientManager(grpcConfigs, trillianclient.GRPCConfig{}) + *tcm = trillianclient.NewClientManager(grpcConfigs, trillianclient.GRPCConfig{}, trillianclient.DefaultConfig()) }, expectErr: true, }, diff --git a/pkg/trillianclient/client_interface.go b/pkg/trillianclient/client_interface.go new file mode 100644 index 000000000..dbb8b9833 --- /dev/null +++ b/pkg/trillianclient/client_interface.go @@ -0,0 +1,67 @@ +// +// Copyright 2026 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trillianclient + +import ( + "context" + + "github.com/google/trillian" + "github.com/google/trillian/types" + "google.golang.org/grpc/codes" +) + +// ClientInterface defines the public API for interacting with a Trillian log. 
+// Two implementations exist: +// - simpleTrillianClient: stateless, per-RPC client (default) +// - TrillianClient: cached STH client with background root updates (experimental, opt-in via CacheSTH) +type ClientInterface interface { + AddLeaf(ctx context.Context, byteValue []byte) *Response + GetLatest(ctx context.Context, firstSize int64) *Response + GetLeafAndProofByHash(ctx context.Context, hash []byte) *Response + GetLeafAndProofByIndex(ctx context.Context, index int64) *Response + GetConsistencyProof(ctx context.Context, firstSize, lastSize int64) *Response + GetLeavesByRange(ctx context.Context, startIndex, count int64) *Response + GetLeafWithoutProof(ctx context.Context, index int64) *Response + Close() +} + +// Response includes a status code, an optional error message, and one of the results based on the API call +type Response struct { + // Status is the status code of the response + Status codes.Code + // Error contains an error on request or client failure + Err error + // GetAddResult contains the response from queueing a leaf in Trillian + GetAddResult *trillian.QueueLeafResponse + // GetLeafAndProofResult contains the response for fetching an inclusion proof and leaf + GetLeafAndProofResult *trillian.GetEntryAndProofResponse + // GetLatestResult contains the response for the latest checkpoint + GetLatestResult *trillian.GetLatestSignedLogRootResponse + // GetConsistencyProofResult contains the response for a consistency proof between two log sizes + GetConsistencyProofResult *trillian.GetConsistencyProofResponse + // GetLeavesByRangeResult contains the response for fetching a leaf without an inclusion proof + GetLeavesByRangeResult *trillian.GetLeavesByRangeResponse + // getProofResult contains the response for an inclusion proof fetched by leaf hash + getProofResult *trillian.GetInclusionProofByHashResponse +} + +func unmarshalLogRoot(logRoot []byte) (types.LogRootV1, error) { + var root types.LogRootV1 + if err := root.UnmarshalBinary(logRoot); 
err != nil { + return types.LogRootV1{}, err + } + return root, nil +} diff --git a/pkg/trillianclient/doc.go b/pkg/trillianclient/doc.go new file mode 100644 index 000000000..c92e7483e --- /dev/null +++ b/pkg/trillianclient/doc.go @@ -0,0 +1,41 @@ +// +// Copyright 2026 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trillianclient + +// Package trillianclient provides Rekor wrappers around Trillian's gRPC API. +// +// Two client modes are supported: +// +// - simpleTrillianClient (default): stateless, per-RPC behavior with no +// background goroutines and no cached root state. +// +// - TrillianClient (enabled with --trillian_log_server.cache_sth): cached +// Signed Tree Head (STH) behavior with a background updater. +// +// In cached mode, the client keeps an atomic snapshot of the latest verified +// root and uses waiter channels to wake only callers whose requested tree size +// has been reached. +// +// Frozen trees (inactive shards) are identified through configuration and are +// treated specially: the client initializes once, does not start an updater, +// and fails fast when callers request sizes that cannot be reached. +// +// The package exposes metrics for updater health, root advancement, and waiting +// behavior to support operational monitoring. +// +// This package intentionally focuses on behavior and architecture. 
Any concrete +// latency or throughput expectations depend on deployment topology, Trillian +// configuration, and workload characteristics. diff --git a/pkg/trillianclient/manager.go b/pkg/trillianclient/manager.go index cdba743fd..c7291d6d3 100644 --- a/pkg/trillianclient/manager.go +++ b/pkg/trillianclient/manager.go @@ -44,8 +44,8 @@ type ClientManager struct { // Mutex for trillianClients map clientMu sync.RWMutex - // trillianClients caches the TrillianClient wrappers. - trillianClients map[int64]*TrillianClient + // trillianClients caches the client wrappers. + trillianClients map[int64]ClientInterface // flag to indicate whether the client manager is shutting down shutdown bool @@ -53,15 +53,18 @@ type ClientManager struct { treeIDToConfig map[int64]GRPCConfig // defaultConfig is the global fallback configuration. defaultConfig GRPCConfig + // clientConfig holds timeout settings for new clients + clientConfig Config } // NewClientManager creates a new ClientManager. -func NewClientManager(treeIDToConfig map[int64]GRPCConfig, defaultConfig GRPCConfig) *ClientManager { +func NewClientManager(treeIDToConfig map[int64]GRPCConfig, defaultConfig GRPCConfig, clientConfig Config) *ClientManager { return &ClientManager{ connections: make(map[GRPCConfig]*grpc.ClientConn), treeIDToConfig: treeIDToConfig, defaultConfig: defaultConfig, - trillianClients: make(map[int64]*TrillianClient), + clientConfig: clientConfig, + trillianClients: make(map[int64]ClientInterface), } } @@ -81,8 +84,28 @@ func (cm *ClientManager) getConn(treeID int64) (*grpc.ClientConn, error) { return conn, nil } + // Check shutdown before dialing. Read clientMu outside connMu to + // maintain consistent lock ordering (GetTrillianClient acquires + // clientMu then calls getConn which acquires connMu). 
+ cm.clientMu.RLock() + shutting := cm.shutdown + cm.clientMu.RUnlock() + if shutting { + return nil, errors.New("client manager is shutting down") + } + cm.connMu.Lock() defer cm.connMu.Unlock() + + // Re-check shutdown after acquiring connMu. Close() may have run + // between the early check and here, draining all connections. + cm.clientMu.RLock() + shutting = cm.shutdown + cm.clientMu.RUnlock() + if shutting { + return nil, errors.New("client manager is shutting down") + } + // Double-check after acquiring the write lock. conn, ok = cm.connections[config] if ok { @@ -99,16 +122,17 @@ func (cm *ClientManager) getConn(treeID int64) (*grpc.ClientConn, error) { } // GetTrillianClient returns a Rekor Trillian client wrapper for the given tree ID. -func (cm *ClientManager) GetTrillianClient(treeID int64) (*TrillianClient, error) { +// When CacheSTH is enabled, returns a cached STH client; otherwise returns a simple per-RPC client. +func (cm *ClientManager) GetTrillianClient(treeID int64) (ClientInterface, error) { cm.clientMu.RLock() if cm.shutdown { cm.clientMu.RUnlock() return nil, errors.New("client manager is shutting down") } - client, ok := cm.trillianClients[treeID] + c, ok := cm.trillianClients[treeID] cm.clientMu.RUnlock() if ok { - return client, nil + return c, nil } conn, err := cm.getConn(treeID) @@ -122,11 +146,16 @@ func (cm *ClientManager) GetTrillianClient(treeID int64) (*TrillianClient, error if cm.shutdown { return nil, errors.New("client manager is shutting down") } - if client, ok = cm.trillianClients[treeID]; ok { - return client, nil + if c, ok = cm.trillianClients[treeID]; ok { + return c, nil } - newClient := newTrillianClient(trillian.NewTrillianLogClient(conn), treeID) + var newClient ClientInterface + if cm.clientConfig.CacheSTH { + newClient = newTrillianClient(trillian.NewTrillianLogClient(conn), treeID, cm.clientConfig) + } else { + newClient = newSimpleTrillianClient(trillian.NewTrillianLogClient(conn), treeID) + } 
cm.trillianClients[treeID] = newClient return newClient, nil } @@ -200,10 +229,11 @@ func dial(hostname string, port uint16, tlsCACertFile string, useSystemTrustStor func (cm *ClientManager) Close() error { var err error - // set shutdown flag to true and clear cache of clients + // Lock ordering: clientMu then connMu (same as GetTrillianClient → getConn). cm.clientMu.Lock() cm.shutdown = true - cm.trillianClients = make(map[int64]*TrillianClient) + oldClients := cm.trillianClients + cm.trillianClients = make(map[int64]ClientInterface) cm.clientMu.Unlock() cm.connMu.Lock() @@ -214,5 +244,10 @@ delete(cm.connections, cfg) } cm.connMu.Unlock() + + // Close clients outside both locks to avoid deadlock (client.Close may block). + for _, c := range oldClients { + c.Close() + } return err } diff --git a/pkg/trillianclient/trillian_client.go b/pkg/trillianclient/trillian_client.go index 237c266a9..5c03d463a 100644 --- a/pkg/trillianclient/trillian_client.go +++ b/pkg/trillianclient/trillian_client.go @@ -20,9 +20,11 @@ import ( + "bytes" "context" "encoding/hex" "fmt" + "sync" + "sync/atomic" "time" - "github.com/transparency-dev/merkle/proof" "github.com/transparency-dev/merkle/rfc6962" "google.golang.org/grpc/codes" @@ -30,63 +31,351 @@ import ( "github.com/google/trillian" "github.com/google/trillian/client" + "github.com/google/trillian/client/backoff" "github.com/google/trillian/types" + "github.com/prometheus/client_golang/prometheus" + "github.com/sigstore/rekor/pkg/log" ) -// TrillianClient provides a wrapper around the Trillian client +// Default timeouts for initialization and updater polling. +// These can be overridden via Config.
+const ( + DefaultInitLatestRootTimeout = 3 * time.Second + DefaultUpdaterWaitTimeout = 3 * time.Second +) + +var ( + metricRootAdvance = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "rekor_trillian_root_advance_total", + Help: "Number of root advances observed by the Trillian client.", + }, + []string{"tree"}, + ) + metricUpdaterErrors = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "rekor_trillian_updater_errors_total", + Help: "Total updater errors (wait/fetch/marshal).", + }, + []string{"tree"}, + ) + metricLatestTreeSize = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "rekor_trillian_latest_tree_size", + Help: "Latest observed tree size per tree.", + }, + []string{"tree"}, + ) + metricWaitForRootAtLeast = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "rekor_trillian_wait_for_root_ms", + Help: "Time spent waiting for the root to reach at least a given size (ms).", + Buckets: prometheus.ExponentialBuckets(1, 2, 12), + }, + []string{"tree", "success"}, + ) + metricInclusionWait = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "rekor_trillian_inclusion_wait_ms", + Help: "Time to obtain an inclusion proof (ms).", + Buckets: prometheus.ExponentialBuckets(1, 2, 12), + }, + []string{"success"}, + ) +) + +func init() { + // Register metrics once. + prometheus.MustRegister( + metricRootAdvance, + metricUpdaterErrors, + metricLatestTreeSize, + metricWaitForRootAtLeast, + metricInclusionWait, + ) +} + +// Config holds configuration options for TrillianClient +type Config struct { + // CacheSTH enables the cached STH client with background root updates (experimental). + // When false (default), the simple per-RPC client is used. 
+ CacheSTH bool + // InitLatestRootTimeout is the timeout for fetching the latest root during initialization + InitLatestRootTimeout time.Duration + // UpdaterWaitTimeout is the timeout for updater polling wait operations + UpdaterWaitTimeout time.Duration + // FrozenTreeIDs contains tree IDs of frozen (inactive) logs. When CacheSTH is + // enabled, cached clients for frozen trees fetch the root once and never start + // a background updater. + FrozenTreeIDs map[int64]bool +} + +// DefaultConfig returns a config with default timeout values +func DefaultConfig() Config { + return Config{ + InitLatestRootTimeout: DefaultInitLatestRootTimeout, + UpdaterWaitTimeout: DefaultUpdaterWaitTimeout, + } +} + +// waiter represents a caller blocked in waitForRootAtLeast. +type waiter struct { + ch chan struct{} + size uint64 +} + +// TrillianClient provides a cached STH wrapper around the Trillian client +// with background root updates and channel-per-caller notification. type TrillianClient struct { client trillian.TrillianLogClient logID int64 + config Config + frozen bool // when true, the tree is frozen; no updater is started + + // shared trillian client/verifier + lc *client.LogClient + v *client.LogVerifier + mu sync.Mutex + waiters []waiter // per-caller notification channels + wg sync.WaitGroup + + // cached root snapshot (atomic for read-heavy paths) + snapshot atomic.Value // stores rootSnapshot + + // lifecycle + started bool + startErr error + stopCh chan struct{} + + // bgCtx is canceled on Close to interrupt long waits in the updater. + bgCtx context.Context + bgCancel context.CancelFunc } -// newTrillianClient creates a TrillianClient with the given Trillian client and log/tree ID. 
-func newTrillianClient(logClient trillian.TrillianLogClient, logID int64) *TrillianClient { - return &TrillianClient{ +type rootSnapshot struct { + root types.LogRootV1 + signed *trillian.SignedLogRoot +} + +// newTrillianClient creates a TrillianClient with the given Trillian client, log/tree ID, and config. +// If the tree ID appears in config.FrozenTreeIDs, the client fetches the root once during +// initialization and never starts the background updater, avoiding wasted RPCs on trees +// that will never advance. +func newTrillianClient(logClient trillian.TrillianLogClient, logID int64, config Config) *TrillianClient { + t := &TrillianClient{ client: logClient, logID: logID, - } + config: config, + frozen: config.FrozenTreeIDs[logID], + stopCh: make(chan struct{}), + } + t.bgCtx, t.bgCancel = context.WithCancel(context.Background()) + // initialize atomic snapshot with zero value + t.snapshot.Store(rootSnapshot{}) + return t } -// Response includes a status code, an optional error message, and one of the results based on the API call -type Response struct { - // Status is the status code of the response - Status codes.Code - // Error contains an error on request or client failure - Err error - // GetAddResult contains the response from queueing a leaf in Trillian - GetAddResult *trillian.QueueLeafResponse - // GetLeafAndProofResult contains the response for fetching an inclusion proof and leaf - GetLeafAndProofResult *trillian.GetEntryAndProofResponse - // GetLatestResult contains the response for the latest checkpoint - GetLatestResult *trillian.GetLatestSignedLogRootResponse - // GetConsistencyProofResult contains the response for a consistency proof between two log sizes - GetConsistencyProofResult *trillian.GetConsistencyProofResponse - // GetLeavesByRangeResult contains the response for fetching a leaf without an inclusion proof - GetLeavesByRangeResult *trillian.GetLeavesByRangeResponse - // getProofResult contains the response for an inclusion proof 
fetched by leaf hash - getProofResult *trillian.GetInclusionProofByHashResponse +// registerWaiter adds a new waiter for the given tree size and returns its channel. +// Must be called with t.mu held. +func (t *TrillianClient) registerWaiter(size uint64) chan struct{} { + ch := make(chan struct{}, 1) + t.waiters = append(t.waiters, waiter{ch: ch, size: size}) + return ch } -func unmarshalLogRoot(logRoot []byte) (types.LogRootV1, error) { - var root types.LogRootV1 - if err := root.UnmarshalBinary(logRoot); err != nil { - return types.LogRootV1{}, err +// removeWaiter removes a waiter by its channel (used for cleanup on context cancellation). +// Must be called with t.mu held. +func (t *TrillianClient) removeWaiter(ch chan struct{}) { + for i, w := range t.waiters { + if w.ch == ch { + t.waiters[i] = t.waiters[len(t.waiters)-1] + t.waiters = t.waiters[:len(t.waiters)-1] + return + } } - return root, nil } -func (t *TrillianClient) root(ctx context.Context) (types.LogRootV1, error) { - rqst := &trillian.GetLatestSignedLogRootRequest{ - LogId: t.logID, +// notifyWaiters closes the channels of all waiters whose requested size +// is satisfied by newSize, and removes them from the slice. +// Must be called with t.mu held. +func (t *TrillianClient) notifyWaiters(newSize uint64) { + remaining := t.waiters[:0] + for _, w := range t.waiters { + if newSize >= w.size { + close(w.ch) + } else { + remaining = append(remaining, w) + } + } + t.waiters = remaining +} + +// ensureStarted initializes the shared LogClient and starts the updater once. +// The mutex serializes concurrent callers so only one performs the initial RPC; +// subsequent callers return the cached result immediately. +func (t *TrillianClient) ensureStarted(ctx context.Context) error { + t.mu.Lock() + defer t.mu.Unlock() + if t.started { + return t.startErr } - resp, err := t.client.GetLatestSignedLogRoot(ctx, rqst) + + // Perform one-time initialization while holding the lock. 
+ // This blocks other ensureStarted callers until initialization completes. + cctx := ctx + var cancel context.CancelFunc + if _, ok := ctx.Deadline(); !ok { + cctx, cancel = context.WithTimeout(ctx, t.config.InitLatestRootTimeout) + } + if cancel != nil { + defer cancel() + } + slr, err := t.client.GetLatestSignedLogRoot(cctx, &trillian.GetLatestSignedLogRootRequest{LogId: t.logID}) if err != nil { - return types.LogRootV1{}, err + t.startErr = err + return err + } + if slr == nil || slr.SignedLogRoot == nil { + err = fmt.Errorf("nil signed log root") + t.startErr = err + return err + } + r, uerr := unmarshalLogRoot(slr.SignedLogRoot.LogRoot) + if uerr != nil { + t.startErr = uerr + return uerr + } + + t.v = client.NewLogVerifier(rfc6962.DefaultHasher) + t.lc = client.New(t.logID, t.client, t.v, r) + t.snapshot.Store(rootSnapshot{root: r, signed: slr.SignedLogRoot}) + t.started = true + t.startErr = nil + + // Start updater only for non-frozen trees. Frozen trees never advance, + // so polling would waste RPCs, goroutines, and pollute error metrics. + if !t.frozen { + t.wg.Add(1) + go func() { + defer t.wg.Done() + t.updater() + }() + } + return nil +} + +// updater waits for root changes using the LogClient and notifies waiters. +// It uses the parsed root from WaitForRootUpdate and synthesizes a minimal +// SignedLogRoot (LogRoot bytes only) to avoid an extra network round trip +// per advancement. 
+func (t *TrillianClient) updater() { + // Create backoff for retry logic with reasonable defaults + bo := backoff.Backoff{ + Min: 100 * time.Millisecond, // Start with 100ms + Max: 30 * time.Second, // Cap at 30s + Factor: 2.0, // Double each time + Jitter: true, // Add randomization + } + for { + // Wrap the WaitForRootUpdate call with backoff retry + var nr *types.LogRootV1 + err := bo.Retry(t.bgCtx, func() error { + select { + case <-t.stopCh: + return fmt.Errorf("client stopped") + default: + } + + ctx, cancel := context.WithTimeout(t.bgCtx, t.config.UpdaterWaitTimeout) + defer cancel() + + var waitErr error + nr, waitErr = t.lc.WaitForRootUpdate(ctx) + return waitErr + }) + select { + case <-t.stopCh: + return + default: + } + + if err != nil { + log.Logger.Debugw("trillian root update wait failed after retries", "treeID", t.logID, "err", err) + metricUpdaterErrors.WithLabelValues(fmt.Sprintf("%d", t.logID)).Inc() + // Do not reset backoff on error; let it accumulate for persistent failures. 
+ continue + } + + // Success - reset backoff for next potential failure + bo.Reset() + + if nr == nil { + continue + } + + // compute change against current snapshot + old := t.snapshot.Load().(rootSnapshot) + changed := nr.TreeSize != old.root.TreeSize || !bytes.Equal(nr.RootHash, old.root.RootHash) + if !changed { + // nothing to publish + continue + } + log.Logger.Debugw("trillian root advanced", "treeID", t.logID, "oldSize", old.root.TreeSize, "newSize", nr.TreeSize) + + // Marshal parsed root to bytes and synthesize a minimal SignedLogRoot + lrBytes, mErr := nr.MarshalBinary() + if mErr != nil { + log.Logger.Debugw("failed to marshal updated log root", "treeID", t.logID, "err", mErr) + metricUpdaterErrors.WithLabelValues(fmt.Sprintf("%d", t.logID)).Inc() + continue + } + slr := &trillian.SignedLogRoot{LogRoot: lrBytes} + + // publish new snapshot and notify waiters + t.mu.Lock() + t.snapshot.Store(rootSnapshot{root: *nr, signed: slr}) + t.notifyWaiters(nr.TreeSize) + t.mu.Unlock() + + // metrics + tree := fmt.Sprintf("%d", t.logID) + metricRootAdvance.WithLabelValues(tree).Inc() + metricLatestTreeSize.WithLabelValues(tree).Set(float64(nr.TreeSize)) } - return unmarshalLogRoot(resp.SignedLogRoot.LogRoot) +} + +// Close stops the updater and unblocks all waiters. +func (t *TrillianClient) Close() { + t.mu.Lock() + // Cancel background operations first to unblock any waits + if t.bgCancel != nil { + t.bgCancel() + } + select { + case <-t.stopCh: + default: + close(t.stopCh) + } + // Close all waiter channels to unblock them + for _, w := range t.waiters { + close(w.ch) + } + t.waiters = nil + t.mu.Unlock() + // Wait for updater to exit + t.wg.Wait() } func (t *TrillianClient) AddLeaf(ctx context.Context, byteValue []byte) *Response { + if err := t.ensureStarted(ctx); err != nil { + return &Response{ + Status: status.Code(err), + Err: err, + } + } + // Capture baseline tree size before queueing to set the first gate correctly. 
+ preSnap, _ := t.snapshot.Load().(rootSnapshot) + baselineSize := preSnap.root.TreeSize leaf := &trillian.LogLeaf{ LeafValue: byteValue, } @@ -95,59 +384,35 @@ func (t *TrillianClient) AddLeaf(ctx context.Context, byteValue []byte) *Respons Leaf: leaf, } resp, err := t.client.QueueLeaf(ctx, rqst) - - // check for error - if err != nil || (resp.QueuedLeaf.Status != nil && resp.QueuedLeaf.Status.Code != int32(codes.OK)) { + if err != nil { return &Response{ Status: status.Code(err), Err: err, GetAddResult: resp, } } - - root, err := t.root(ctx) - if err != nil { + if resp == nil || resp.QueuedLeaf == nil || resp.QueuedLeaf.Leaf == nil { return &Response{ - Status: status.Code(err), - Err: err, - GetAddResult: resp, + Status: codes.Internal, + Err: fmt.Errorf("unexpected nil in QueueLeaf response"), } } - v := client.NewLogVerifier(rfc6962.DefaultHasher) - logClient := client.New(t.logID, t.client, v, root) - - waitForInclusion := func(ctx context.Context, _ []byte) *Response { - if logClient.MinMergeDelay > 0 { - select { - case <-ctx.Done(): - return &Response{ - Status: codes.DeadlineExceeded, - Err: ctx.Err(), - } - case <-time.After(logClient.MinMergeDelay): - } - } - for { - root = *logClient.GetRoot() - if root.TreeSize >= 1 { - proofResp := t.getProofByHash(ctx, resp.QueuedLeaf.Leaf.MerkleLeafHash) - // if this call succeeds or returns an error other than "not found", return - if proofResp.Err == nil || (proofResp.Err != nil && status.Code(proofResp.Err) != codes.NotFound) { - return proofResp - } - // otherwise wait for a root update before trying again - } - - if _, err := logClient.WaitForRootUpdate(ctx); err != nil { - return &Response{ - Status: codes.Unknown, - Err: err, - } - } + // Non-OK insertion status (e.g. ALREADY_EXISTS) is not a gRPC error. + // Return Status: OK with the response so callers can inspect QueuedLeaf.Status + // to determine the insertion-level outcome (e.g. HTTP 409 for duplicates). 
+ if resp.QueuedLeaf.Status != nil && resp.QueuedLeaf.Status.Code != int32(codes.OK) { + return &Response{ + Status: codes.OK, + GetAddResult: resp, } } - proofResp := waitForInclusion(ctx, resp.QueuedLeaf.Leaf.MerkleLeafHash) + // Gate the first proof attempt on the next root advance relative to the + // snapshot observed here. This avoids an almost-always NotFound on the + // very first try and trims unnecessary RPCs without impacting latency + // (we need a root advance to include the leaf anyway). + minSize := baselineSize + 1 + proofResp := t.waitForInclusionWithMinSize(ctx, resp.QueuedLeaf.Leaf.MerkleLeafHash, minSize) if proofResp.Err != nil { return &Response{ Status: status.Code(proofResp.Err), @@ -188,15 +453,22 @@ func (t *TrillianClient) AddLeaf(ctx context.Context, byteValue []byte) *Respons } func (t *TrillianClient) GetLeafAndProofByHash(ctx context.Context, hash []byte) *Response { - // get inclusion proof for hash, extract index, then fetch leaf using index - proofResp := t.getProofByHash(ctx, hash) + if err := t.ensureStarted(ctx); err != nil { + return &Response{ + Status: status.Code(err), + Err: err, + } + } + snap, _ := t.snapshot.Load().(rootSnapshot) + root := snap.root + signed := snap.signed + proofResp := t.getProofByHashWithRoot(ctx, hash, root, signed) if proofResp.Err != nil { return &Response{ Status: status.Code(proofResp.Err), Err: proofResp.Err, } } - proofs := proofResp.getProofResult.Proof if len(proofs) != 1 { err := fmt.Errorf("expected 1 proof from getProofByHash for %v, found %v", hex.EncodeToString(hash), len(proofs)) @@ -220,75 +492,86 @@ func (t *TrillianClient) GetLeafAndProofByHash(ctx context.Context, hash []byte) } func (t *TrillianClient) GetLeafAndProofByIndex(ctx context.Context, index int64) *Response { - rootResp := t.GetLatest(ctx, 0) - if rootResp.Err != nil { + if err := t.ensureStarted(ctx); err != nil { return &Response{ - Status: status.Code(rootResp.Err), - Err: rootResp.Err, + Status: status.Code(err), + 
Err: err, } } + snap, _ := t.snapshot.Load().(rootSnapshot) + root := snap.root + signed := snap.signed - root, err := unmarshalLogRoot(rootResp.GetLatestResult.SignedLogRoot.LogRoot) + resp, err := t.client.GetEntryAndProof(ctx, &trillian.GetEntryAndProofRequest{ + LogId: t.logID, + LeafIndex: index, + TreeSize: int64(root.TreeSize), + }) if err != nil { return &Response{ Status: status.Code(err), Err: err, } } - - resp, err := t.client.GetEntryAndProof(ctx, - &trillian.GetEntryAndProofRequest{ - LogId: t.logID, - LeafIndex: index, - TreeSize: int64(root.TreeSize), //nolint:gosec - }) - if resp != nil && resp.Proof != nil { - if err := proof.VerifyInclusion(rfc6962.DefaultHasher, uint64(index), root.TreeSize, resp.GetLeaf().MerkleLeafHash, resp.Proof.Hashes, root.RootHash); err != nil { //nolint:gosec + if err := t.v.VerifyInclusionByHash(&root, resp.GetLeaf().MerkleLeafHash, resp.Proof); err != nil { return &Response{ Status: status.Code(err), Err: err, } } return &Response{ - Status: status.Code(err), - Err: err, + Status: codes.OK, GetLeafAndProofResult: &trillian.GetEntryAndProofResponse{ Proof: resp.Proof, Leaf: resp.Leaf, - SignedLogRoot: rootResp.GetLatestResult.SignedLogRoot, + SignedLogRoot: signed, }, } } - return &Response{ - Status: status.Code(err), - Err: err, + Status: codes.NotFound, + Err: fmt.Errorf("trillian returned empty response for index %d", index), } } -func (t *TrillianClient) GetLatest(ctx context.Context, leafSizeInt int64) *Response { - resp, err := t.client.GetLatestSignedLogRoot(ctx, - &trillian.GetLatestSignedLogRootRequest{ - LogId: t.logID, - FirstTreeSize: leafSizeInt, - }) - +func (t *TrillianClient) GetLatest(ctx context.Context, firstSize int64) *Response { + if err := t.ensureStarted(ctx); err != nil { + return &Response{ + Status: status.Code(err), + Err: err, + } + } + if firstSize > 0 { + if err := t.waitForRootAtLeast(ctx, uint64(firstSize)); err != nil { + return &Response{ + Status: status.Code(err), + Err: err, + } + } 
+ } + snap, _ := t.snapshot.Load().(rootSnapshot) + signed := snap.signed + if signed == nil { + return &Response{ + Status: codes.NotFound, + Err: status.Error(codes.NotFound, "no signed root available"), + } + } return &Response{ - Status: status.Code(err), - Err: err, - GetLatestResult: resp, + Status: codes.OK, + GetLatestResult: &trillian.GetLatestSignedLogRootResponse{ + SignedLogRoot: signed, + }, } } func (t *TrillianClient) GetConsistencyProof(ctx context.Context, firstSize, lastSize int64) *Response { - resp, err := t.client.GetConsistencyProof(ctx, - &trillian.GetConsistencyProofRequest{ - LogId: t.logID, - FirstTreeSize: firstSize, - SecondTreeSize: lastSize, - }) - + resp, err := t.client.GetConsistencyProof(ctx, &trillian.GetConsistencyProofRequest{ + LogId: t.logID, + FirstTreeSize: firstSize, + SecondTreeSize: lastSize, + }) return &Response{ Status: status.Code(err), Err: err, @@ -296,61 +579,153 @@ func (t *TrillianClient) GetConsistencyProof(ctx context.Context, firstSize, las } } -func (t *TrillianClient) getProofByHash(ctx context.Context, hashValue []byte) *Response { - rootResp := t.GetLatest(ctx, 0) - if rootResp.Err != nil { +func (t *TrillianClient) getProofByHashWithRoot(ctx context.Context, hashValue []byte, root types.LogRootV1, signed *trillian.SignedLogRoot) *Response { + // issue 1308: if the tree is empty, there's no way we can return a proof + if root.TreeSize == 0 { return &Response{ - Status: status.Code(rootResp.Err), - Err: rootResp.Err, + Status: codes.NotFound, + Err: status.Error(codes.NotFound, "tree is empty"), } } - root, err := unmarshalLogRoot(rootResp.GetLatestResult.SignedLogRoot.LogRoot) + resp, err := t.client.GetInclusionProofByHash(ctx, &trillian.GetInclusionProofByHashRequest{ + LogId: t.logID, + LeafHash: hashValue, + TreeSize: int64(root.TreeSize), //nolint:gosec + }) if err != nil { return &Response{ Status: status.Code(err), Err: err, } } - - // issue 1308: if the tree is empty, there's no way we can return a 
proof - if root.TreeSize == 0 { - return &Response{ - Status: codes.NotFound, - Err: status.Error(codes.NotFound, "tree is empty"), - } - } - - resp, err := t.client.GetInclusionProofByHash(ctx, - &trillian.GetInclusionProofByHashRequest{ - LogId: t.logID, - LeafHash: hashValue, - TreeSize: int64(root.TreeSize), //nolint:gosec - }) - if resp != nil { - v := client.NewLogVerifier(rfc6962.DefaultHasher) - for _, proof := range resp.Proof { - if err := v.VerifyInclusionByHash(&root, hashValue, proof); err != nil { + for _, p := range resp.Proof { + if err := t.v.VerifyInclusionByHash(&root, hashValue, p); err != nil { return &Response{ Status: status.Code(err), Err: err, } } } - // Return an inclusion proof response with the requested return &Response{ - Status: status.Code(err), - Err: err, + Status: codes.OK, getProofResult: &trillian.GetInclusionProofByHashResponse{ Proof: resp.Proof, - SignedLogRoot: rootResp.GetLatestResult.SignedLogRoot, + SignedLogRoot: signed, }, } } - return &Response{ - Status: status.Code(err), - Err: err, + Status: codes.Unknown, + Err: fmt.Errorf("trillian returned empty proof for hash %s", hex.EncodeToString(hashValue)), + } +} + +// waitForInclusionWithMinSize behaves like waitForInclusion but ensures the +// first inclusion-proof attempt happens only after the tree has reached at +// least minSize. This reduces initial NotFound churn without increasing time +// to success (since inclusion requires a root advance). +func (t *TrillianClient) waitForInclusionWithMinSize(ctx context.Context, leafHash []byte, minSize uint64) *Response { + start := time.Now() + + // Optionally delay the very first attempt until minSize is reached. + // If the current snapshot is already beyond minSize, this returns immediately. 
+ if err := t.waitForRootAtLeast(ctx, minSize); err != nil { + elapsed := float64(time.Since(start).Milliseconds()) + metricInclusionWait.WithLabelValues("false").Observe(elapsed) + return &Response{Status: status.Code(err), Err: err} + } + + for { + if err := ctx.Err(); err != nil { + elapsed := float64(time.Since(start).Milliseconds()) + metricInclusionWait.WithLabelValues("false").Observe(elapsed) + return &Response{Status: status.Code(err), Err: err} + } + snap, _ := t.snapshot.Load().(rootSnapshot) + root := snap.root + signed := snap.signed + + proofResp := t.getProofByHashWithRoot(ctx, leafHash, root, signed) + if proofResp.Err == nil || status.Code(proofResp.Err) != codes.NotFound { + success := proofResp.Err == nil + elapsed := float64(time.Since(start).Milliseconds()) + metricInclusionWait.WithLabelValues(fmt.Sprintf("%t", success)).Observe(elapsed) + return proofResp + } + + // NotFound: wait for the tree to grow and try again + if err := t.waitForRootAtLeast(ctx, root.TreeSize+1); err != nil { + return &Response{Status: status.Code(err), Err: err} + } + } +} + +// waitForRootAtLeast blocks until the cached root TreeSize >= size, or context/client closes. +// For frozen trees, returns immediately with an error if the current size is insufficient, +// since the tree will never advance. +func (t *TrillianClient) waitForRootAtLeast(ctx context.Context, size uint64) error { + start := time.Now() + tree := fmt.Sprintf("%d", t.logID) + + // Fast path: check without lock + cur := t.snapshot.Load().(rootSnapshot) + if cur.root.TreeSize >= size { + elapsed := float64(time.Since(start).Milliseconds()) + metricWaitForRootAtLeast.WithLabelValues(tree, "true").Observe(elapsed) + return nil + } + + // Frozen trees will never advance; fail immediately rather than blocking forever. 
+ if t.frozen { + elapsed := float64(time.Since(start).Milliseconds()) + metricWaitForRootAtLeast.WithLabelValues(tree, "false").Observe(elapsed) + return status.Errorf(codes.FailedPrecondition, "tree %d is frozen at size %d, requested %d", t.logID, cur.root.TreeSize, size) + } + + // Register waiter + t.mu.Lock() + // Re-check under lock (snapshot may have advanced) + cur = t.snapshot.Load().(rootSnapshot) + if cur.root.TreeSize >= size { + t.mu.Unlock() + elapsed := float64(time.Since(start).Milliseconds()) + metricWaitForRootAtLeast.WithLabelValues(tree, "true").Observe(elapsed) + return nil + } + ch := t.registerWaiter(size) + t.mu.Unlock() + + // Wait on channel, context, or stop. + // When multiple channels fire simultaneously, select picks one + // non-deterministically. After receiving from ch, we re-check stopCh + // to avoid returning success during shutdown. + select { + case <-ch: + select { + case <-t.stopCh: + elapsed := float64(time.Since(start).Milliseconds()) + metricWaitForRootAtLeast.WithLabelValues(tree, "false").Observe(elapsed) + return status.Error(codes.Canceled, "client closed") + default: + } + elapsed := float64(time.Since(start).Milliseconds()) + metricWaitForRootAtLeast.WithLabelValues(tree, "true").Observe(elapsed) + return nil + case <-ctx.Done(): + t.mu.Lock() + t.removeWaiter(ch) + t.mu.Unlock() + elapsed := float64(time.Since(start).Milliseconds()) + metricWaitForRootAtLeast.WithLabelValues(tree, "false").Observe(elapsed) + return ctx.Err() + case <-t.stopCh: + t.mu.Lock() + t.removeWaiter(ch) + t.mu.Unlock() + elapsed := float64(time.Since(start).Milliseconds()) + metricWaitForRootAtLeast.WithLabelValues(tree, "false").Observe(elapsed) + return status.Error(codes.Canceled, "client closed") } } diff --git a/pkg/trillianclient/trillian_client_simple.go b/pkg/trillianclient/trillian_client_simple.go new file mode 100644 index 000000000..557e573b7 --- /dev/null +++ b/pkg/trillianclient/trillian_client_simple.go @@ -0,0 +1,403 @@ +// 
+// Copyright 2021 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trillianclient + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "time" + + "github.com/transparency-dev/merkle/proof" + "github.com/transparency-dev/merkle/rfc6962" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/google/trillian" + "github.com/google/trillian/client" + "github.com/google/trillian/types" +) + +// simpleTrillianClient is a stateless, per-RPC wrapper around the Trillian gRPC +// client. It fetches a fresh root on every operation that requires one, with no +// background goroutines or cached state. +type simpleTrillianClient struct { + client trillian.TrillianLogClient + logID int64 +} + +// newSimpleTrillianClient creates a simpleTrillianClient. 
+func newSimpleTrillianClient(logClient trillian.TrillianLogClient, logID int64) *simpleTrillianClient { + return &simpleTrillianClient{ + client: logClient, + logID: logID, + } +} + +func (t *simpleTrillianClient) root(ctx context.Context) (types.LogRootV1, error) { + rqst := &trillian.GetLatestSignedLogRootRequest{ + LogId: t.logID, + } + resp, err := t.client.GetLatestSignedLogRoot(ctx, rqst) + if err != nil { + return types.LogRootV1{}, err + } + if resp == nil || resp.SignedLogRoot == nil { + return types.LogRootV1{}, fmt.Errorf("nil signed log root in response") + } + return unmarshalLogRoot(resp.SignedLogRoot.LogRoot) +} + +func (t *simpleTrillianClient) AddLeaf(ctx context.Context, byteValue []byte) *Response { + leaf := &trillian.LogLeaf{ + LeafValue: byteValue, + } + rqst := &trillian.QueueLeafRequest{ + LogId: t.logID, + Leaf: leaf, + } + resp, err := t.client.QueueLeaf(ctx, rqst) + if err != nil { + return &Response{ + Status: status.Code(err), + Err: err, + GetAddResult: resp, + } + } + if resp == nil || resp.QueuedLeaf == nil || resp.QueuedLeaf.Leaf == nil { + return &Response{ + Status: codes.Internal, + Err: fmt.Errorf("unexpected nil in QueueLeaf response"), + } + } + // Non-OK insertion status (e.g. ALREADY_EXISTS) is not a gRPC error. + // Return Status: OK with the response so callers can inspect QueuedLeaf.Status + // to determine the insertion-level outcome (e.g. HTTP 409 for duplicates). 
+ if resp.QueuedLeaf.Status != nil && resp.QueuedLeaf.Status.Code != int32(codes.OK) { + return &Response{ + Status: codes.OK, + GetAddResult: resp, + } + } + + root, err := t.root(ctx) + if err != nil { + return &Response{ + Status: status.Code(err), + Err: err, + GetAddResult: resp, + } + } + v := client.NewLogVerifier(rfc6962.DefaultHasher) + logClient := client.New(t.logID, t.client, v, root) + + waitForInclusion := func(ctx context.Context) *Response { + if logClient.MinMergeDelay > 0 { + select { + case <-ctx.Done(): + return &Response{ + Status: codes.DeadlineExceeded, + Err: ctx.Err(), + } + case <-time.After(logClient.MinMergeDelay): + } + } + for { + root = *logClient.GetRoot() + if root.TreeSize >= 1 { + proofResp := t.getProofByHash(ctx, resp.QueuedLeaf.Leaf.MerkleLeafHash) + if proofResp.Err == nil || (proofResp.Err != nil && status.Code(proofResp.Err) != codes.NotFound) { + return proofResp + } + } + + if _, err := logClient.WaitForRootUpdate(ctx); err != nil { + return &Response{ + Status: codes.Unknown, + Err: err, + } + } + } + } + + proofResp := waitForInclusion(ctx) + if proofResp.Err != nil { + return &Response{ + Status: status.Code(proofResp.Err), + Err: proofResp.Err, + GetAddResult: resp, + } + } + + proofs := proofResp.getProofResult.Proof + if len(proofs) != 1 { + err := fmt.Errorf("expected 1 proof from getProofByHash for %v, found %v", hex.EncodeToString(resp.QueuedLeaf.Leaf.MerkleLeafHash), len(proofs)) + return &Response{ + Status: status.Code(err), + Err: err, + GetAddResult: resp, + } + } + + leafIndex := proofs[0].LeafIndex + leafOnlyResp := t.getStandaloneLeaf(ctx, leafIndex, resp.QueuedLeaf.Leaf.MerkleLeafHash, proofs[0], proofResp.getProofResult.SignedLogRoot) + if leafOnlyResp.Err != nil { + return &Response{ + Status: status.Code(leafOnlyResp.Err), + Err: leafOnlyResp.Err, + GetAddResult: resp, + } + } + + resp.QueuedLeaf.Leaf = leafOnlyResp.GetLeafAndProofResult.Leaf + + return &Response{ + Status: codes.OK, + GetAddResult: 
resp, + GetLeafAndProofResult: leafOnlyResp.GetLeafAndProofResult, + } +} + +func (t *simpleTrillianClient) GetLeafAndProofByHash(ctx context.Context, hash []byte) *Response { + proofResp := t.getProofByHash(ctx, hash) + if proofResp.Err != nil { + return &Response{ + Status: status.Code(proofResp.Err), + Err: proofResp.Err, + } + } + + proofs := proofResp.getProofResult.Proof + if len(proofs) != 1 { + err := fmt.Errorf("expected 1 proof from getProofByHash for %v, found %v", hex.EncodeToString(hash), len(proofs)) + return &Response{ + Status: status.Code(err), + Err: err, + } + } + + leafIndex := proofs[0].LeafIndex + leafOnlyResp := t.getStandaloneLeaf(ctx, leafIndex, hash, proofs[0], proofResp.getProofResult.SignedLogRoot) + if leafOnlyResp.Err != nil { + return &Response{ + Status: status.Code(leafOnlyResp.Err), + Err: leafOnlyResp.Err, + } + } + + return leafOnlyResp +} + +func (t *simpleTrillianClient) GetLeafAndProofByIndex(ctx context.Context, index int64) *Response { + rootResp := t.GetLatest(ctx, 0) + if rootResp.Err != nil { + return &Response{ + Status: status.Code(rootResp.Err), + Err: rootResp.Err, + } + } + + root, err := unmarshalLogRoot(rootResp.GetLatestResult.SignedLogRoot.LogRoot) + if err != nil { + return &Response{ + Status: status.Code(err), + Err: err, + } + } + + resp, err := t.client.GetEntryAndProof(ctx, + &trillian.GetEntryAndProofRequest{ + LogId: t.logID, + LeafIndex: index, + TreeSize: int64(root.TreeSize), + }) + + if resp != nil && resp.Proof != nil { + if err := proof.VerifyInclusion(rfc6962.DefaultHasher, uint64(index), root.TreeSize, resp.GetLeaf().MerkleLeafHash, resp.Proof.Hashes, root.RootHash); err != nil { + return &Response{ + Status: status.Code(err), + Err: err, + } + } + return &Response{ + Status: status.Code(err), + Err: err, + GetLeafAndProofResult: &trillian.GetEntryAndProofResponse{ + Proof: resp.Proof, + Leaf: resp.Leaf, + SignedLogRoot: rootResp.GetLatestResult.SignedLogRoot, + }, + } + } + + return &Response{ + 
Status: status.Code(err), + Err: err, + } +} + +func (t *simpleTrillianClient) GetLatest(ctx context.Context, leafSizeInt int64) *Response { + resp, err := t.client.GetLatestSignedLogRoot(ctx, + &trillian.GetLatestSignedLogRootRequest{ + LogId: t.logID, + FirstTreeSize: leafSizeInt, + }) + + return &Response{ + Status: status.Code(err), + Err: err, + GetLatestResult: resp, + } +} + +func (t *simpleTrillianClient) GetConsistencyProof(ctx context.Context, firstSize, lastSize int64) *Response { + resp, err := t.client.GetConsistencyProof(ctx, + &trillian.GetConsistencyProofRequest{ + LogId: t.logID, + FirstTreeSize: firstSize, + SecondTreeSize: lastSize, + }) + + return &Response{ + Status: status.Code(err), + Err: err, + GetConsistencyProofResult: resp, + } +} + +func (t *simpleTrillianClient) getProofByHash(ctx context.Context, hashValue []byte) *Response { + rootResp := t.GetLatest(ctx, 0) + if rootResp.Err != nil { + return &Response{ + Status: status.Code(rootResp.Err), + Err: rootResp.Err, + } + } + root, err := unmarshalLogRoot(rootResp.GetLatestResult.SignedLogRoot.LogRoot) + if err != nil { + return &Response{ + Status: status.Code(err), + Err: err, + } + } + + if root.TreeSize == 0 { + return &Response{ + Status: codes.NotFound, + Err: status.Error(codes.NotFound, "tree is empty"), + } + } + + resp, err := t.client.GetInclusionProofByHash(ctx, + &trillian.GetInclusionProofByHashRequest{ + LogId: t.logID, + LeafHash: hashValue, + TreeSize: int64(root.TreeSize), + }) + + if resp != nil { + v := client.NewLogVerifier(rfc6962.DefaultHasher) + for _, p := range resp.Proof { + if err := v.VerifyInclusionByHash(&root, hashValue, p); err != nil { + return &Response{ + Status: status.Code(err), + Err: err, + } + } + } + return &Response{ + Status: status.Code(err), + Err: err, + getProofResult: &trillian.GetInclusionProofByHashResponse{ + Proof: resp.Proof, + SignedLogRoot: rootResp.GetLatestResult.SignedLogRoot, + }, + } + } + + return &Response{ + Status: 
status.Code(err), + Err: err, + } +} + +// GetLeavesByRange fetches leaves from startIndex (inclusive) up to count leaves without proofs. +func (t *simpleTrillianClient) GetLeavesByRange(ctx context.Context, startIndex, count int64) *Response { + resp, err := t.client.GetLeavesByRange(ctx, &trillian.GetLeavesByRangeRequest{ + LogId: t.logID, + StartIndex: startIndex, + Count: count, + }) + return &Response{ + Status: status.Code(err), + Err: err, + GetLeavesByRangeResult: resp, + } +} + +// GetLeafWithoutProof is a convenience wrapper for fetching a single leaf by index without proofs. +func (t *simpleTrillianClient) GetLeafWithoutProof(ctx context.Context, index int64) *Response { + return t.GetLeavesByRange(ctx, index, 1) +} + +// Close is a no-op for the simple client (no background goroutines). +func (t *simpleTrillianClient) Close() {} + +// getStandaloneLeaf gets just the leaf, returns it in GetLeafAndProof result for easier reuse. +func (t *simpleTrillianClient) getStandaloneLeaf(ctx context.Context, index int64, hash []byte, p *trillian.Proof, signedRoot *trillian.SignedLogRoot) *Response { + leafOnlyResp := t.GetLeafWithoutProof(ctx, index) + if leafOnlyResp.Err != nil { + return &Response{ + Status: status.Code(leafOnlyResp.Err), + Err: leafOnlyResp.Err, + } + } + + if leafOnlyResp.GetLeavesByRangeResult == nil || len(leafOnlyResp.GetLeavesByRangeResult.Leaves) == 0 { + err := fmt.Errorf("no leaf returned for index %d", index) + return &Response{ + Status: codes.NotFound, + Err: err, + } + } + if len(leafOnlyResp.GetLeavesByRangeResult.Leaves) != 1 { + err := fmt.Errorf("multiple leaves returned for index %d", index) + return &Response{ + Status: codes.FailedPrecondition, + Err: err, + } + } + leaf := leafOnlyResp.GetLeavesByRangeResult.Leaves[0] + + if !bytes.Equal(leaf.MerkleLeafHash, hash) { + err := fmt.Errorf("leaf hash mismatch: expected %v, got %v", hex.EncodeToString(hash), hex.EncodeToString(leaf.MerkleLeafHash)) + return &Response{ + Status: 
status.Code(err), + Err: err, + } + } + + return &Response{ + Status: codes.OK, + GetLeafAndProofResult: &trillian.GetEntryAndProofResponse{ + Proof: p, + Leaf: leaf, + SignedLogRoot: signedRoot, + }, + } +} diff --git a/pkg/trillianclient/trillian_client_simple_test.go b/pkg/trillianclient/trillian_client_simple_test.go new file mode 100644 index 000000000..b28a9800a --- /dev/null +++ b/pkg/trillianclient/trillian_client_simple_test.go @@ -0,0 +1,185 @@ +// +// Copyright 2025 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package trillianclient + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/trillian" + "github.com/google/trillian/testonly" + "github.com/google/trillian/types" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" +) + +func TestSimpleClient_GetLatest(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + s, closeFn, err := testonly.NewMockServer(mockCtl) + require.NoError(t, err) + defer closeFn() + + slr := mkSLR(t, 5, make([]byte, 32)) + s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return( + &trillian.GetLatestSignedLogRootResponse{SignedLogRoot: slr}, nil, + ).Times(1) + + conn := dialMock(t, s.Addr) + tc := newSimpleTrillianClient(trillian.NewTrillianLogClient(conn), 42) + + resp := tc.GetLatest(context.Background(), 0) + require.NoError(t, resp.Err) + require.Equal(t, codes.OK, resp.Status) + require.NotNil(t, resp.GetLatestResult) + require.NotNil(t, resp.GetLatestResult.SignedLogRoot) + + var got types.LogRootV1 + require.NoError(t, got.UnmarshalBinary(resp.GetLatestResult.SignedLogRoot.LogRoot)) + require.EqualValues(t, 5, got.TreeSize) +} + +func TestSimpleClient_GetLatest_WithFirstTreeSize(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + s, closeFn, err := testonly.NewMockServer(mockCtl) + require.NoError(t, err) + defer closeFn() + + slr := mkSLR(t, 10, make([]byte, 32)) + s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, req *trillian.GetLatestSignedLogRootRequest) (*trillian.GetLatestSignedLogRootResponse, error) { + // Verify FirstTreeSize is passed through to Trillian + require.EqualValues(t, 5, req.FirstTreeSize) + return &trillian.GetLatestSignedLogRootResponse{SignedLogRoot: slr}, nil + }, + ).Times(1) + + conn := dialMock(t, s.Addr) + tc := newSimpleTrillianClient(trillian.NewTrillianLogClient(conn), 42) + + resp := 
tc.GetLatest(context.Background(), 5) + require.NoError(t, resp.Err) + require.Equal(t, codes.OK, resp.Status) +} + +func TestSimpleClient_Close_DoesNotAffectRPCs(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + s, closeFn, err := testonly.NewMockServer(mockCtl) + require.NoError(t, err) + defer closeFn() + + slr := mkSLR(t, 3, make([]byte, 32)) + s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return( + &trillian.GetLatestSignedLogRootResponse{SignedLogRoot: slr}, nil, + ).Times(1) + + conn := dialMock(t, s.Addr) + tc := newSimpleTrillianClient(trillian.NewTrillianLogClient(conn), 42) + + // Close is intentionally a no-op for the simple client. + tc.Close() + + resp := tc.GetLatest(context.Background(), 0) + require.NoError(t, resp.Err) + require.Equal(t, codes.OK, resp.Status) +} + +func TestSimpleClient_GetConsistencyProof(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + s, closeFn, err := testonly.NewMockServer(mockCtl) + require.NoError(t, err) + defer closeFn() + + s.Log.EXPECT().GetConsistencyProof(gomock.Any(), gomock.Any()).Return( + &trillian.GetConsistencyProofResponse{ + Proof: &trillian.Proof{Hashes: [][]byte{make([]byte, 32)}}, + }, nil, + ).Times(1) + + conn := dialMock(t, s.Addr) + tc := newSimpleTrillianClient(trillian.NewTrillianLogClient(conn), 42) + + resp := tc.GetConsistencyProof(context.Background(), 1, 5) + require.NoError(t, resp.Err) + require.Equal(t, codes.OK, resp.Status) + require.NotNil(t, resp.GetConsistencyProofResult) +} + +func TestSimpleClient_GetLeavesByRange(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + s, closeFn, err := testonly.NewMockServer(mockCtl) + require.NoError(t, err) + defer closeFn() + + s.Log.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, req *trillian.GetLeavesByRangeRequest) (*trillian.GetLeavesByRangeResponse, error) { + require.EqualValues(t, 42, 
req.LogId) + require.EqualValues(t, 0, req.StartIndex) + require.EqualValues(t, 1, req.Count) + return &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{{LeafIndex: 0, MerkleLeafHash: make([]byte, 32)}}, + }, nil + }, + ).Times(1) + + conn := dialMock(t, s.Addr) + tc := newSimpleTrillianClient(trillian.NewTrillianLogClient(conn), 42) + + resp := tc.GetLeavesByRange(context.Background(), 0, 1) + require.NoError(t, resp.Err) + require.Equal(t, codes.OK, resp.Status) + require.NotNil(t, resp.GetLeavesByRangeResult) + require.Len(t, resp.GetLeavesByRangeResult.Leaves, 1) +} + +func TestSimpleClient_GetLeafWithoutProof_DelegatesToRange(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + s, closeFn, err := testonly.NewMockServer(mockCtl) + require.NoError(t, err) + defer closeFn() + + s.Log.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, req *trillian.GetLeavesByRangeRequest) (*trillian.GetLeavesByRangeResponse, error) { + require.EqualValues(t, 42, req.LogId) + require.EqualValues(t, 7, req.StartIndex) + require.EqualValues(t, 1, req.Count) + return &trillian.GetLeavesByRangeResponse{ + Leaves: []*trillian.LogLeaf{{LeafIndex: 7, MerkleLeafHash: make([]byte, 32)}}, + }, nil + }, + ).Times(1) + + conn := dialMock(t, s.Addr) + tc := newSimpleTrillianClient(trillian.NewTrillianLogClient(conn), 42) + + resp := tc.GetLeafWithoutProof(context.Background(), 7) + require.NoError(t, resp.Err) + require.Equal(t, codes.OK, resp.Status) + require.NotNil(t, resp.GetLeavesByRangeResult) + require.Len(t, resp.GetLeavesByRangeResult.Leaves, 1) +} diff --git a/pkg/trillianclient/trillian_client_test.go b/pkg/trillianclient/trillian_client_test.go new file mode 100644 index 000000000..cba94513a --- /dev/null +++ b/pkg/trillianclient/trillian_client_test.go @@ -0,0 +1,1045 @@ +// +// Copyright 2025 The Sigstore Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trillianclient + +import ( + "bytes" + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/trillian" + "github.com/google/trillian/client" + "github.com/google/trillian/testonly" + "github.com/google/trillian/types" + "github.com/stretchr/testify/require" + "github.com/transparency-dev/merkle/rfc6962" + "go.uber.org/goleak" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +// helper to build a SignedLogRoot with given tree size and root hash +func mkSLR(t *testing.T, size uint64, rootHash []byte) *trillian.SignedLogRoot { + t.Helper() + lr := &types.LogRootV1{TreeSize: size, RootHash: rootHash} + b, err := lr.MarshalBinary() + require.NoError(t, err) + return &trillian.SignedLogRoot{LogRoot: b} +} + +func dialMock(t *testing.T, addr string) *grpc.ClientConn { + t.Helper() + conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + return conn +} + +// advanceRoot updates the cached snapshot and notifies waiters via the channel-per-caller mechanism. 
+func advanceRoot(t *testing.T, tc *TrillianClient, size uint64, rootHash []byte) { + t.Helper() + lr := &types.LogRootV1{TreeSize: size, RootHash: rootHash} + b, err := lr.MarshalBinary() + require.NoError(t, err) + tc.mu.Lock() + tc.snapshot.Store(rootSnapshot{root: *lr, signed: &trillian.SignedLogRoot{LogRoot: b}}) + tc.notifyWaiters(size) + tc.mu.Unlock() +} + +type fakeCloseTrackingClient struct { + closeCalls int32 +} + +func (f *fakeCloseTrackingClient) AddLeaf(_ context.Context, _ []byte) *Response { + return &Response{Status: codes.OK} +} + +func (f *fakeCloseTrackingClient) GetLatest(_ context.Context, _ int64) *Response { + return &Response{Status: codes.OK} +} + +func (f *fakeCloseTrackingClient) GetLeafAndProofByHash(_ context.Context, _ []byte) *Response { + return &Response{Status: codes.OK} +} + +func (f *fakeCloseTrackingClient) GetLeafAndProofByIndex(_ context.Context, _ int64) *Response { + return &Response{Status: codes.OK} +} + +func (f *fakeCloseTrackingClient) GetConsistencyProof(_ context.Context, _, _ int64) *Response { + return &Response{Status: codes.OK} +} + +func (f *fakeCloseTrackingClient) GetLeavesByRange(_ context.Context, _, _ int64) *Response { + return &Response{Status: codes.OK} +} + +func (f *fakeCloseTrackingClient) GetLeafWithoutProof(_ context.Context, _ int64) *Response { + return &Response{Status: codes.OK} +} + +func (f *fakeCloseTrackingClient) Close() { + atomic.AddInt32(&f.closeCalls, 1) +} + +func (f *fakeCloseTrackingClient) CloseCalls() int32 { + return atomic.LoadInt32(&f.closeCalls) +} + +func TestEnsureStartedAndGetLatest(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + s, closeFn, err := testonly.NewMockServer(mockCtl) + require.NoError(t, err) + defer closeFn() + + // Initial root (empty tree) + slr := mkSLR(t, 0, make([]byte, 32)) + s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: slr}, 
nil).MinTimes(1) + + conn := dialMock(t, s.Addr) + tc := newTrillianClient(trillian.NewTrillianLogClient(conn), 42, DefaultConfig()) + t.Cleanup(tc.Close) + + resp := tc.GetLatest(context.Background(), 0) + require.NoError(t, resp.Err) + require.Equal(t, codes.OK, resp.Status) + require.NotNil(t, resp.GetLatestResult) + require.NotNil(t, resp.GetLatestResult.SignedLogRoot) + + // Unmarshal and check size + var got types.LogRootV1 + require.NoError(t, got.UnmarshalBinary(resp.GetLatestResult.SignedLogRoot.LogRoot)) + require.EqualValues(t, 0, got.TreeSize) +} + +// Note: waiting for an advance via client.WaitForRootUpdate is exercised indirectly +// in other tests (AddLeaf), and is hard to deterministically simulate across +// environments with the mock server; we avoid a direct "firstSize" wait test here. + +func TestGetLeafAndProofByIndex_VerifiesProof(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + s, closeFn, err := testonly.NewMockServer(mockCtl) + require.NoError(t, err) + defer closeFn() + + // Tree of size 1, root equals leaf hash. Empty proof should verify. 
	rootHash := make([]byte, 32)
	slr1 := mkSLR(t, 1, rootHash)
	s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: slr1}, nil).MinTimes(1)

	s.Log.EXPECT().GetEntryAndProof(gomock.Any(), gomock.Any()).DoAndReturn(
		func(_ context.Context, r *trillian.GetEntryAndProofRequest) (*trillian.GetEntryAndProofResponse, error) {
			// Ensure we were asked for the current tree size
			if r.TreeSize != 1 || r.LeafIndex != 0 {
				return nil, status.Error(codes.InvalidArgument, "unexpected request")
			}
			return &trillian.GetEntryAndProofResponse{
				Leaf:  &trillian.LogLeaf{MerkleLeafHash: rootHash},
				Proof: &trillian.Proof{LeafIndex: 0, Hashes: nil},
			}, nil
		},
	).Times(1)

	conn := dialMock(t, s.Addr)
	tc := newTrillianClient(trillian.NewTrillianLogClient(conn), 9, DefaultConfig())
	t.Cleanup(tc.Close)

	resp := tc.GetLeafAndProofByIndex(context.Background(), 0)
	require.NoError(t, resp.Err)
	require.Equal(t, codes.OK, resp.Status)
	require.NotNil(t, resp.GetLeafAndProofResult)
}

// TestGetLeafAndProofByHash_VerifiesProof: inclusion proof lookup by leaf
// hash, followed by a GetLeavesByRange fetch for the leaf body.
func TestGetLeafAndProofByHash_VerifiesProof(t *testing.T) {
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	rootHash := make([]byte, 32)
	slr1 := mkSLR(t, 1, rootHash)
	s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: slr1}, nil).MinTimes(1)

	// Inclusion proof for hash -> index 0, empty path is valid in size=1
	s.Log.EXPECT().GetInclusionProofByHash(gomock.Any(), gomock.Any()).Return(
		&trillian.GetInclusionProofByHashResponse{Proof: []*trillian.Proof{{LeafIndex: 0, Hashes: nil}}}, nil,
	).Times(1)

	// GetLeafAndProofByHash now calls GetLeavesByRange to fetch the leaf
	s.Log.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).DoAndReturn(
		func(_ context.Context, r *trillian.GetLeavesByRangeRequest) (*trillian.GetLeavesByRangeResponse, error) {
			if r.Count != 1 || r.StartIndex != 0 {
				return nil, status.Error(codes.InvalidArgument, "unexpected range request")
			}
			return &trillian.GetLeavesByRangeResponse{Leaves: []*trillian.LogLeaf{{MerkleLeafHash: rootHash}}}, nil
		},
	).Times(1)

	conn := dialMock(t, s.Addr)
	tc := newTrillianClient(trillian.NewTrillianLogClient(conn), 13, DefaultConfig())
	t.Cleanup(tc.Close)

	resp := tc.GetLeafAndProofByHash(context.Background(), rootHash)
	require.NoError(t, resp.Err)
	require.Equal(t, codes.OK, resp.Status)
	require.NotNil(t, resp.GetLeafAndProofResult)
}

// TestAddLeaf_HappyPath: queue a leaf, simulate the root advancing from
// size=0 to size=2, and check the full add + inclusion-proof flow.
func TestAddLeaf_HappyPath(t *testing.T) {
	opt := goleak.IgnoreCurrent()
	t.Cleanup(func() { goleak.VerifyNone(t, opt) })
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	leafHash := make([]byte, 32) // leaf 0 hash

	// We'll simulate a root advance from size=0 to size=2 and return a
	// proof for leaf index 0 with a single sibling. Compute a consistent
	// root hash for size=2 so verification succeeds.
	sibling := bytes.Repeat([]byte{0x7f}, 32) // arbitrary sibling hash
	root2 := rfc6962.DefaultHasher.HashChildren(leafHash, sibling)
	slr0 := mkSLR(t, 0, make([]byte, 32))

	// QueueLeaf returns quickly
	s.Log.EXPECT().QueueLeaf(gomock.Any(), gomock.Any()).Return(&trillian.QueueLeafResponse{
		QueuedLeaf: &trillian.QueuedLogLeaf{Leaf: &trillian.LogLeaf{MerkleLeafHash: leafHash}},
	}, nil).Times(1)

	// We bypass ensureStarted's network init and the updater by pre-initializing
	// the client snapshot and verifier, then manually advancing the snapshot.

	// Inclusion proof by hash: success for size=2 with sibling path
	s.Log.EXPECT().GetInclusionProofByHash(gomock.Any(), gomock.Any()).Return(
		&trillian.GetInclusionProofByHashResponse{Proof: []*trillian.Proof{{LeafIndex: 0, Hashes: [][]byte{sibling}}}}, nil,
	).Times(1)

	// After inclusion, client fetches leaf by index without proof to get server-populated fields
	s.Log.EXPECT().GetLeavesByRange(gomock.Any(), gomock.Any()).DoAndReturn(
		func(_ context.Context, r *trillian.GetLeavesByRangeRequest) (*trillian.GetLeavesByRangeResponse, error) {
			if r.Count != 1 || r.StartIndex != 0 {
				return nil, status.Error(codes.InvalidArgument, "unexpected range request")
			}
			return &trillian.GetLeavesByRangeResponse{Leaves: []*trillian.LogLeaf{{MerkleLeafHash: leafHash}}}, nil
		},
	).Times(1)

	conn := dialMock(t, s.Addr)
	tc := newTrillianClient(trillian.NewTrillianLogClient(conn), 21, DefaultConfig())
	// Pre-initialize
	tc.started = true
	tc.v = client.NewLogVerifier(rfc6962.DefaultHasher)
	tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 0, RootHash: make([]byte, 32)}, signed: slr0})
	// Advance snapshot to size=2 after a short delay to release waiters
	go func() {
		time.Sleep(20 * time.Millisecond)
		advanceRoot(t, tc, 2, root2)
	}()
	t.Cleanup(tc.Close)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	resp := tc.AddLeaf(ctx, []byte("hello"))
	require.NoError(t, resp.Err)
	require.Equal(t, codes.OK, resp.Status)
	require.NotNil(t, resp.GetAddResult)
	require.NotNil(t, resp.GetLeafAndProofResult)
}

// TestGetLatestFirstSizeCanceledOnClose: a GetLatest waiting for firstSize=1
// against a tree stuck at size 0 must be unblocked by Close with Canceled.
func TestGetLatestFirstSizeCanceledOnClose(t *testing.T) {
	opt := goleak.IgnoreCurrent()
	t.Cleanup(func() { goleak.VerifyNone(t, opt) })
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	// Always return size=0 so waiter would block
	slr0 := mkSLR(t, 0, make([]byte, 32))
	s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: slr0}, nil).MinTimes(1)

	conn := dialMock(t, s.Addr)
	tc := newTrillianClient(trillian.NewTrillianLogClient(conn), 33, DefaultConfig())

	done := make(chan *Response, 1)
	go func() {
		done <- tc.GetLatest(context.Background(), 1) // would block until size>=1
	}()

	// Give it a moment to start waiting
	time.Sleep(50 * time.Millisecond)
	tc.Close()

	select {
	case r := <-done:
		require.Error(t, r.Err)
		require.Equal(t, codes.Canceled, r.Status)
	case <-time.After(2 * time.Second):
		t.Fatal("GetLatest did not return after Close")
	}
}

// TestEnsureStartedError: a failed initial root fetch surfaces the backend's
// gRPC status to the caller.
func TestEnsureStartedError(t *testing.T) {
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(nil, status.Error(codes.Unavailable, "boom")).Times(1)

	conn := dialMock(t, s.Addr)
	tc := newTrillianClient(trillian.NewTrillianLogClient(conn), 99, DefaultConfig())
	t.Cleanup(tc.Close)

	resp := tc.GetLatest(context.Background(), 0)
	require.Error(t, resp.Err)
	require.Equal(t, codes.Unavailable, resp.Status)
}

// TestWaitForRootAtLeast_BroadcastWakesAll: a single root advance must wake
// every waiter registered for that (or a smaller) size.
func TestWaitForRootAtLeast_BroadcastWakesAll(t *testing.T) {
	opt := goleak.IgnoreCurrent()
	t.Cleanup(func() { goleak.VerifyNone(t, opt) })
	tc := newTrillianClient(nil, 100, DefaultConfig())
	// Start with size 0
	tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 0}})
	t.Cleanup(tc.Close)

	const numWaiters = 10
	var wg sync.WaitGroup
	wg.Add(numWaiters)

	errs := make(chan error, numWaiters)
	for i := 0; i < numWaiters; i++ {
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()
			errs <- tc.waitForRootAtLeast(ctx, 5)
		}()
	}

	// Give goroutines time to register as waiters
	time.Sleep(20 * time.Millisecond)

	// Publish new root and notify waiters
	advanceRoot(t, tc, 5, make([]byte, 32))

	done := make(chan struct{})
	go func() { wg.Wait(); close(done) }()

	select {
	case <-done:
		close(errs)
		for e := range errs {
			require.NoError(t, e)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("waiters did not unblock after notification")
	}
}

// TestGetLatest_WithFirstSize_BroadcastWakesAll: same broadcast behavior, but
// exercised through the public GetLatest(firstSize) path.
func TestGetLatest_WithFirstSize_BroadcastWakesAll(t *testing.T) {
	opt := goleak.IgnoreCurrent()
	t.Cleanup(func() { goleak.VerifyNone(t, opt) })
	tc := newTrillianClient(nil, 101, DefaultConfig())
	// Mark as started to bypass network init in GetLatest
	tc.started = true
	// initial snapshot with size 0
	tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 0}, signed: mkSLR(t, 0, make([]byte, 32))})
	t.Cleanup(tc.Close)

	const numWaiters = 8
	var wg sync.WaitGroup
	wg.Add(numWaiters)

	results := make(chan *Response, numWaiters)
	for i := 0; i < numWaiters; i++ {
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()
			results <- tc.GetLatest(ctx, 5)
		}()
	}

	// Small delay to let goroutines register waiters
	time.Sleep(20 * time.Millisecond)

	// Publish root size 5 and notify waiters
	advanceRoot(t, tc, 5, make([]byte, 32))

	done := make(chan struct{})
	go func() { wg.Wait(); close(done) }()

	select {
	case <-done:
		close(results)
		for r := range results {
			require.NoError(t, r.Err)
			require.Equal(t, codes.OK, r.Status)
			require.NotNil(t, r.GetLatestResult)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("GetLatest waiters did not unblock after notification")
	}
}

// TestEnsureStarted_SingleRPCWithFanIn: 20 concurrent first calls must fan in
// to exactly one GetLatestSignedLogRoot RPC (enforced by Times(1)).
func TestEnsureStarted_SingleRPCWithFanIn(t *testing.T) {
	opt := goleak.IgnoreCurrent()
	t.Cleanup(func() { goleak.VerifyNone(t, opt) })
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	slr := mkSLR(t, 0, make([]byte, 32))
	s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).DoAndReturn(
		func(_ context.Context, _ *trillian.GetLatestSignedLogRootRequest) (*trillian.GetLatestSignedLogRootResponse, error) {
			time.Sleep(30 * time.Millisecond)
			return &trillian.GetLatestSignedLogRootResponse{SignedLogRoot: slr}, nil
		},
	).Times(1)

	conn := dialMock(t, s.Addr)
	cfg := DefaultConfig()
	cfg.FrozenTreeIDs = map[int64]bool{222: true} // prevents updater RPC noise
	tc := newTrillianClient(trillian.NewTrillianLogClient(conn), 222, cfg)
	t.Cleanup(tc.Close)

	const n = 20
	var wg sync.WaitGroup
	wg.Add(n)
	errs := make(chan error, n)
	for i := 0; i < n; i++ {
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()
			r := tc.GetLatest(ctx, 0)
			errs <- r.Err
		}()
	}
	wg.Wait()
	close(errs)
	for e := range errs {
		require.NoError(t, e)
	}
}

// TestWaitForRootAtLeast_SpuriousBroadcastIgnored: waiters registered for a
// larger size must not be woken by an advance that falls short of it.
func TestWaitForRootAtLeast_SpuriousBroadcastIgnored(t *testing.T) {
	opt := goleak.IgnoreCurrent()
	t.Cleanup(func() { goleak.VerifyNone(t, opt) })
	tc := newTrillianClient(nil, 303, DefaultConfig())
	tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 1}})
	t.Cleanup(tc.Close)

	const numWaiters = 6
	var wg sync.WaitGroup
	wg.Add(numWaiters)
	results := make(chan error, numWaiters)
	for i := 0; i < numWaiters; i++ {
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
			defer cancel()
			results <- tc.waitForRootAtLeast(ctx, 5)
		}()
	}

	// Give them time to register as waiters
	time.Sleep(20 * time.Millisecond)

	// With channel-per-caller, there is no "spurious" broadcast; waiters are only
	// notified when their target size is met. Advance to size 3 (below target 5);
	// waiters for size 5 should not be notified.
+ advanceRoot(t, tc, 3, make([]byte, 32)) + time.Sleep(30 * time.Millisecond) + require.Zero(t, len(results), "waiters should not exit when size is below target") + + // Now increase size to 5; everyone should complete + advanceRoot(t, tc, 5, make([]byte, 32)) + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + close(results) + for e := range results { + require.NoError(t, e) + } + case <-time.After(2 * time.Second): + t.Fatal("waiters did not complete after size increased") + } +} + +func TestSnapshotConcurrentReadersWriters_NoDataRace(t *testing.T) { + opt := goleak.IgnoreCurrent() + t.Cleanup(func() { goleak.VerifyNone(t, opt) }) + tc := newTrillianClient(nil, 404, DefaultConfig()) + tc.started = true + // Provide a minimal signed root so GetLatest can return without NotFound + tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 0}, signed: mkSLR(t, 0, make([]byte, 32))}) + t.Cleanup(tc.Close) + + stop := make(chan struct{}) + + // Writer rapidly updates snapshot, notifying waiters each time + go func() { + ticker := time.NewTicker(1 * time.Millisecond) + defer ticker.Stop() + sz := uint64(0) + for i := 0; i < 100; i++ { + <-ticker.C + sz++ + lr := &types.LogRootV1{TreeSize: sz} + b, _ := lr.MarshalBinary() + tc.mu.Lock() + tc.snapshot.Store(rootSnapshot{root: *lr, signed: &trillian.SignedLogRoot{LogRoot: b}}) + tc.notifyWaiters(sz) + tc.mu.Unlock() + } + close(stop) + }() + + // Readers call GetLatest repeatedly + const readers = 16 + var wg sync.WaitGroup + wg.Add(readers) + for i := 0; i < readers; i++ { + go func() { + defer wg.Done() + for { + select { + case <-stop: + return + default: + } + r := tc.GetLatest(context.Background(), 0) + require.NoError(t, r.Err) + require.NotNil(t, r.GetLatestResult) + } + }() + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + // success + case <-time.After(500 * time.Millisecond): + t.Fatal("concurrent readers 
did not complete in time") + } +} + +func TestEnsureStartedDeadlineRespected(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + s, closeFn, err := testonly.NewMockServer(mockCtl) + require.NoError(t, err) + defer closeFn() + + // Server handler sleeps longer than client deadline; client should return DeadlineExceeded. + // Use AnyTimes because the deadline may expire before the RPC reaches the server. + s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, _ *trillian.GetLatestSignedLogRootRequest) (*trillian.GetLatestSignedLogRootResponse, error) { + time.Sleep(200 * time.Millisecond) + return &trillian.GetLatestSignedLogRootResponse{SignedLogRoot: mkSLR(t, 0, make([]byte, 32))}, nil + }, + ).AnyTimes() + + conn := dialMock(t, s.Addr) + tc := newTrillianClient(trillian.NewTrillianLogClient(conn), 606, DefaultConfig()) + t.Cleanup(tc.Close) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + resp := tc.GetLatest(ctx, 0) + require.Error(t, resp.Err) + require.Equal(t, codes.DeadlineExceeded, resp.Status) +} + +// --- New tests for channel-per-caller and edge cases --- + +func TestWaitForRootAtLeast_AlreadySatisfied(t *testing.T) { + tc := newTrillianClient(nil, 500, DefaultConfig()) + tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 10}}) + t.Cleanup(tc.Close) + + err := tc.waitForRootAtLeast(context.Background(), 5) + require.NoError(t, err) + + err = tc.waitForRootAtLeast(context.Background(), 10) + require.NoError(t, err) +} + +func TestWaitForRootAtLeast_ContextCancellation(t *testing.T) { + opt := goleak.IgnoreCurrent() + t.Cleanup(func() { goleak.VerifyNone(t, opt) }) + tc := newTrillianClient(nil, 501, DefaultConfig()) + tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 0}}) + t.Cleanup(tc.Close) + + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan error, 1) + go func() { + 
		done <- tc.waitForRootAtLeast(ctx, 100)
	}()

	// Give the goroutine time to register as a waiter
	time.Sleep(20 * time.Millisecond)

	// Cancel context - should immediately unblock
	cancel()

	select {
	case err := <-done:
		require.Error(t, err)
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(500 * time.Millisecond):
		t.Fatal("waiter was not unblocked by context cancellation")
	}
}

// TestClose_UnblocksAllWaiters: Close must release every blocked waiter with
// an error rather than leaving them parked.
func TestClose_UnblocksAllWaiters(t *testing.T) {
	opt := goleak.IgnoreCurrent()
	t.Cleanup(func() { goleak.VerifyNone(t, opt) })
	tc := newTrillianClient(nil, 502, DefaultConfig())
	tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 0}})

	const numWaiters = 5
	var wg sync.WaitGroup
	wg.Add(numWaiters)
	errs := make(chan error, numWaiters)

	for i := 0; i < numWaiters; i++ {
		go func() {
			defer wg.Done()
			errs <- tc.waitForRootAtLeast(context.Background(), 999)
		}()
	}

	// Give goroutines time to register
	time.Sleep(20 * time.Millisecond)
	tc.Close()

	done := make(chan struct{})
	go func() { wg.Wait(); close(done) }()

	select {
	case <-done:
		close(errs)
		for e := range errs {
			require.Error(t, e)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("Close did not unblock all waiters")
	}
}

// TestNotifyWaiters_PartialSatisfaction: notifyWaiters(5) must close the
// channels of waiters at sizes <=5 and leave larger waiters registered.
func TestNotifyWaiters_PartialSatisfaction(t *testing.T) {
	tc := newTrillianClient(nil, 503, DefaultConfig())
	tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 0}})
	t.Cleanup(tc.Close)

	tc.mu.Lock()
	ch3 := tc.registerWaiter(3)
	ch5 := tc.registerWaiter(5)
	ch10 := tc.registerWaiter(10)
	tc.mu.Unlock()

	// Notify with size 5: should satisfy waiters for 3 and 5, but not 10
	tc.mu.Lock()
	tc.notifyWaiters(5)
	tc.mu.Unlock()

	// ch3 and ch5 should be closed (readable immediately)
	select {
	case <-ch3:
		// expected
	default:
		t.Fatal("waiter for size 3 should have been notified")
	}
	select {
	case <-ch5:
		// expected
	default:
		t.Fatal("waiter for size 5 should have been notified")
	}

	// ch10 should NOT be closed
	select {
	case <-ch10:
		t.Fatal("waiter for size 10 should NOT have been notified")
	default:
		// expected
	}

	// Verify remaining waiters count
	tc.mu.Lock()
	require.Len(t, tc.waiters, 1)
	require.Equal(t, uint64(10), tc.waiters[0].size)
	tc.mu.Unlock()
}

// TestRemoveWaiter_Cleanup: removeWaiter deletes exactly the matching entry
// and tolerates unknown channels.
func TestRemoveWaiter_Cleanup(t *testing.T) {
	tc := newTrillianClient(nil, 504, DefaultConfig())
	tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 0}})
	t.Cleanup(tc.Close)

	tc.mu.Lock()
	ch1 := tc.registerWaiter(5)
	ch2 := tc.registerWaiter(10)
	require.Len(t, tc.waiters, 2)

	tc.removeWaiter(ch1)
	require.Len(t, tc.waiters, 1)
	require.Equal(t, ch2, tc.waiters[0].ch)

	// Remove non-existent channel is a no-op
	tc.removeWaiter(make(chan struct{}))
	require.Len(t, tc.waiters, 1)
	tc.mu.Unlock()
}

// TestUpdater_RetriesOnTransientErrors: the background updater keeps polling
// after Unavailable errors and stops promptly on Close.
func TestUpdater_RetriesOnTransientErrors(t *testing.T) {
	opt := goleak.IgnoreCurrent()
	t.Cleanup(func() { goleak.VerifyNone(t, opt) })
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	rootHash1 := bytes.Repeat([]byte{0x11}, 32)

	var latestCalls int32
	s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).DoAndReturn(
		func(_ context.Context, _ *trillian.GetLatestSignedLogRootRequest) (*trillian.GetLatestSignedLogRootResponse, error) {
			atomic.AddInt32(&latestCalls, 1)
			return nil, status.Error(codes.Unavailable, "transient")
		},
	).AnyTimes()

	conn := dialMock(t, s.Addr)
	tc := newTrillianClient(trillian.NewTrillianLogClient(conn), 600, Config{
		UpdaterWaitTimeout: 50 * time.Millisecond,
	})
	t.Cleanup(tc.Close)

	initial := types.LogRootV1{TreeSize: 1, RootHash: rootHash1}
	tc.v = client.NewLogVerifier(rfc6962.DefaultHasher)
	tc.lc = client.New(tc.logID, tc.client, tc.v, initial)
	tc.snapshot.Store(rootSnapshot{root: initial, signed: mkSLR(t, 1, rootHash1)})

	done := make(chan struct{})
	go func() {
		tc.updater()
		close(done)
	}()

	deadline := time.Now().Add(2 * time.Second)
	for time.Now().Before(deadline) {
		if atomic.LoadInt32(&latestCalls) >= 2 {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}

	require.GreaterOrEqual(t, atomic.LoadInt32(&latestCalls), int32(2), "updater should keep retrying after transient errors")

	tc.Close()
	select {
	case <-done:
	case <-time.After(500 * time.Millisecond):
		t.Fatal("updater did not stop after Close")
	}
}

// TestClientManager_CachesClientPerTreeID: one client instance per tree ID,
// shared across repeated lookups.
func TestClientManager_CachesClientPerTreeID(t *testing.T) {
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	cfg := Config{CacheSTH: false}
	cm := NewClientManager(nil, GRPCConfig{Address: s.Addr, Port: 0}, cfg)

	conn := dialMock(t, s.Addr)
	cm.connMu.Lock()
	cm.connections[cm.defaultConfig] = conn
	cm.connMu.Unlock()
	t.Cleanup(func() { _ = cm.Close() })

	c1, err := cm.GetTrillianClient(7)
	require.NoError(t, err)
	c2, err := cm.GetTrillianClient(7)
	require.NoError(t, err)
	c3, err := cm.GetTrillianClient(8)
	require.NoError(t, err)

	require.Same(t, c1, c2, "same tree ID should return cached client instance")
	require.NotSame(t, c1, c3, "different tree IDs should return distinct client instances")
}

// TestClientManagerClose_ClosesClients: Close closes every cached client,
// empties the cache, marks shutdown, and rejects further lookups.
func TestClientManagerClose_ClosesClients(t *testing.T) {
	cm := NewClientManager(nil, GRPCConfig{Address: "localhost", Port: 0}, Config{})
	fake1 := &fakeCloseTrackingClient{}
	fake2 := &fakeCloseTrackingClient{}

	cm.clientMu.Lock()
	cm.trillianClients[1] = fake1
	cm.trillianClients[2] = fake2
	cm.clientMu.Unlock()

	err := cm.Close()
	require.NoError(t, err)
	require.EqualValues(t, 1, fake1.CloseCalls(), "Close should be called on cached client 1")
	require.EqualValues(t, 1, fake2.CloseCalls(), "Close should be called on cached client 2")

	cm.clientMu.RLock()
	require.True(t, cm.shutdown)
	require.Empty(t, cm.trillianClients)
	cm.clientMu.RUnlock()

	// After Close, GetTrillianClient should fail
	_, err = cm.GetTrillianClient(1)
	require.Error(t, err)
	require.Contains(t, err.Error(), "shutting down")
}

// TestClientManagerGetConn_RejectsDialAfterClose: dialing is refused once the
// manager has shut down, even with an empty connections map.
func TestClientManagerGetConn_RejectsDialAfterClose(t *testing.T) {
	// Verify that getConn refuses to dial after Close has drained connections,
	// even if the early shutdown check passed before Close ran.
	cfg := Config{CacheSTH: false}
	cm := NewClientManager(nil, GRPCConfig{Address: "localhost", Port: 0}, cfg)

	// Close drains connections and sets shutdown.
	require.NoError(t, cm.Close())

	// getConn must reject the dial attempt despite the connections map being empty.
	_, err := cm.getConn(1)
	require.Error(t, err)
	require.Contains(t, err.Error(), "shutting down")
}

// TestClientManagerGetConn_ConcurrentCloseNeverLeaks: races getConn against
// Close; no connection may be leaked whichever interleaving occurs.
func TestClientManagerGetConn_ConcurrentCloseNeverLeaks(t *testing.T) {
	opt := goleak.IgnoreCurrent()
	t.Cleanup(func() { goleak.VerifyNone(t, opt) })

	cfg := Config{CacheSTH: false}
	cm := NewClientManager(nil, GRPCConfig{Address: "localhost", Port: 0}, cfg)

	// Race getConn against Close. getConn must either succeed (connection
	// stored and later cleaned up) or return a shutdown error. It must
	// never leave an orphaned connection.
	const goroutines = 20
	errs := make(chan error, goroutines)
	var wg sync.WaitGroup
	wg.Add(goroutines)
	for range goroutines {
		go func() {
			defer wg.Done()
			_, err := cm.getConn(1)
			errs <- err
		}()
	}

	// Close concurrently.
	closeErr := cm.Close()
	wg.Wait()
	close(errs)

	require.NoError(t, closeErr)
	for e := range errs {
		if e != nil {
			require.Contains(t, e.Error(), "shutting down")
		}
	}

	// After Close, connections map must be empty (no leaked connections).
	cm.connMu.RLock()
	require.Empty(t, cm.connections)
	cm.connMu.RUnlock()
}

// TestClientManagerFactory_SimpleClient: with CacheSTH disabled the manager
// hands out simpleTrillianClient instances.
func TestClientManagerFactory_SimpleClient(t *testing.T) {
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	cfg := Config{CacheSTH: false}
	cm := NewClientManager(nil, GRPCConfig{Address: s.Addr, Port: 0}, cfg)

	// Manually inject a connection since we can't dial properly in tests
	conn := dialMock(t, s.Addr)
	cm.connMu.Lock()
	cm.connections[cm.defaultConfig] = conn
	cm.connMu.Unlock()
	t.Cleanup(func() { _ = cm.Close() })

	c, err := cm.GetTrillianClient(1)
	require.NoError(t, err)
	_, ok := c.(*simpleTrillianClient)
	require.True(t, ok, "expected simpleTrillianClient when CacheSTH=false")
}

// TestClientManagerFactory_CachedClient: with CacheSTH enabled the manager
// hands out caching *TrillianClient instances.
func TestClientManagerFactory_CachedClient(t *testing.T) {
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	cfg := Config{
		CacheSTH:              true,
		InitLatestRootTimeout: DefaultInitLatestRootTimeout,
		UpdaterWaitTimeout:    DefaultUpdaterWaitTimeout,
	}
	cm := NewClientManager(nil, GRPCConfig{Address: s.Addr, Port: 0}, cfg)

	// Manually inject a connection
	conn := dialMock(t, s.Addr)
	cm.connMu.Lock()
	cm.connections[cm.defaultConfig] = conn
	cm.connMu.Unlock()
	t.Cleanup(func() { _ = cm.Close() })

	c, err := cm.GetTrillianClient(1)
	require.NoError(t, err)
	_, ok := c.(*TrillianClient)
	require.True(t, ok, "expected *TrillianClient when CacheSTH=true")
}

// --- Frozen tree tests ---

// TestFrozenClient_NoUpdaterStarted: a frozen tree fetches its root exactly
// once (Times(1)) and never starts background polling.
func TestFrozenClient_NoUpdaterStarted(t *testing.T) {
	opt := goleak.IgnoreCurrent()
	t.Cleanup(func() { goleak.VerifyNone(t, opt) })
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	slr := mkSLR(t, 10, make([]byte, 32))
	s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(
		&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: slr}, nil,
	).Times(1) // Only called once during ensureStarted; no updater polling

	conn := dialMock(t, s.Addr)
	frozenCfg := DefaultConfig()
	frozenCfg.FrozenTreeIDs = map[int64]bool{700: true}
	tc := newTrillianClient(trillian.NewTrillianLogClient(conn), 700, frozenCfg)
	t.Cleanup(tc.Close)

	resp := tc.GetLatest(context.Background(), 0)
	require.NoError(t, resp.Err)
	require.Equal(t, codes.OK, resp.Status)

	var got types.LogRootV1
	require.NoError(t, got.UnmarshalBinary(resp.GetLatestResult.SignedLogRoot.LogRoot))
	require.EqualValues(t, 10, got.TreeSize)
}

// TestFrozenClient_WaitForRootAtLeast_FailsImmediately: waiting beyond a
// frozen tree's size fails fast with FailedPrecondition instead of blocking.
func TestFrozenClient_WaitForRootAtLeast_FailsImmediately(t *testing.T) {
	frozenCfg := DefaultConfig()
	frozenCfg.FrozenTreeIDs = map[int64]bool{701: true}
	tc := newTrillianClient(nil, 701, frozenCfg)
	tc.snapshot.Store(rootSnapshot{root: types.LogRootV1{TreeSize: 5}})
	t.Cleanup(tc.Close)

	// Request satisfied by current size
	err := tc.waitForRootAtLeast(context.Background(), 5)
	require.NoError(t, err)

	// Request above frozen size fails immediately
	err = tc.waitForRootAtLeast(context.Background(), 10)
	require.Error(t, err)
	require.Equal(t, codes.FailedPrecondition, status.Code(err))
	require.Contains(t, err.Error(), "frozen")
}

// TestClientManagerFactory_FrozenCachedClient: the manager marks cached
// clients frozen iff their tree ID is listed in FrozenTreeIDs.
func TestClientManagerFactory_FrozenCachedClient(t *testing.T) {
	mockCtl := gomock.NewController(t)
	defer mockCtl.Finish()

	s, closeFn, err := testonly.NewMockServer(mockCtl)
	require.NoError(t, err)
	defer closeFn()

	cfg := Config{
		CacheSTH:              true,
		InitLatestRootTimeout: DefaultInitLatestRootTimeout,
		UpdaterWaitTimeout:    DefaultUpdaterWaitTimeout,
		FrozenTreeIDs:         map[int64]bool{42: true},
	}
	cm := NewClientManager(nil, GRPCConfig{Address: s.Addr, Port: 0}, cfg)

	conn := dialMock(t, s.Addr)
	cm.connMu.Lock()
	cm.connections[cm.defaultConfig] = conn
	cm.connMu.Unlock()
	t.Cleanup(func() { _ = cm.Close() })

	c, err := cm.GetTrillianClient(42)
	require.NoError(t, err)
	tc, ok := c.(*TrillianClient)
	require.True(t, ok, "expected *TrillianClient when CacheSTH=true")
	require.True(t, tc.frozen, "expected frozen=true for tree in frozenTreeIDs")

	// Non-frozen tree should not be frozen
	c2, err := cm.GetTrillianClient(99)
	require.NoError(t, err)
	tc2, ok := c2.(*TrillianClient)
	require.True(t, ok)
	require.False(t, tc2.frozen, "expected frozen=false for tree not in frozenTreeIDs")
}
diff --git a/tests/k6/rekor-load.js b/tests/k6/rekor-load.js
new file mode 100644
index 000000000..5198233a4
--- /dev/null
+++ b/tests/k6/rekor-load.js
@@ -0,0 +1,304 @@
/*
 * Copyright 2026 The Sigstore Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import http from 'k6/http';
import { check, sleep } from 'k6';
import { Trend } from 'k6/metrics';
import encoding from 'k6/encoding';

// ---------------------------------------------------------------------------
// Configuration via environment variables
// ---------------------------------------------------------------------------
const REKOR_URL = __ENV.REKOR_URL || 'http://localhost:3000';
const WRITE_QPS = parseInt(__ENV.WRITE_QPS || '50', 10);
const WRITER_PRE_VUS = parseInt(__ENV.WRITER_PRE_VUS || '50', 10);
const WRITER_MAX_VUS = parseInt(__ENV.WRITER_MAX_VUS || '200', 10);
const TAILER_VUS = parseInt(__ENV.TAILER_VUS || '10', 10);
const DURATION = __ENV.DURATION || '1m';
const TAILER_POLL_MS = parseInt(__ENV.TAILER_POLL_MS || '500', 10);

// ---------------------------------------------------------------------------
// Custom metrics
// ---------------------------------------------------------------------------
const entryAddDuration = new Trend('rekor_entry_add_duration', true);
const entryReadDuration = new Trend('rekor_entry_read_duration', true);

// ---------------------------------------------------------------------------
// k6 options
// ---------------------------------------------------------------------------
export const options = {
  scenarios: {
    // Constant write load at WRITE_QPS entries/second.
    writer: {
      executor: 'constant-arrival-rate',
      rate: WRITE_QPS,
      timeUnit: '1s',
      duration: DURATION,
      preAllocatedVUs: WRITER_PRE_VUS,
      maxVUs: WRITER_MAX_VUS,
      exec: 'writer',
    },
    // Readers start slightly after writers so there are entries to tail.
    tailer: {
      executor: 'constant-vus',
      vus: TAILER_VUS,
      duration: DURATION,
      exec: 'tailer',
      startTime: '3s',
    },
  },
  thresholds: {
    'http_req_duration{scenario:writer}': ['p(95)<5000'],
    'http_req_duration{scenario:tailer}': ['p(95)<1000'],
  },
};

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

// arrayBufToHex renders an ArrayBuffer as a lowercase hex string.
function arrayBufToHex(buf) {
  const bytes = new Uint8Array(buf);
  let hex = '';
  for (let i = 0; i < bytes.length; i++) {
    hex += bytes[i].toString(16).padStart(2, '0');
  }
  return hex;
}

// NOTE(review): k6's encoding.b64encode documents string|ArrayBuffer input;
// confirm a Uint8Array is accepted here, otherwise pass `buf` directly.
function arrayBufToBase64(buf) {
  return encoding.b64encode(new Uint8Array(buf), 'std');
}

// derToPem wraps DER bytes in a PEM envelope with 64-character base64 lines.
function derToPem(derBuf, label) {
  const b64 = arrayBufToBase64(derBuf);
  let pem = `-----BEGIN ${label}-----\n`;
  for (let i = 0; i < b64.length; i += 64) {
    pem += b64.slice(i, i + 64) + '\n';
  }
  pem += `-----END ${label}-----\n`;
  return pem;
}

// Convert IEEE P1363 signature (r||s) to ASN.1 DER SEQUENCE(INTEGER(r), INTEGER(s)).
// P-256 produces 64-byte P1363 signatures (32 bytes each for r and s).
function p1363ToDer(p1363Buf) {
  const p1363 = new Uint8Array(p1363Buf);
  const half = p1363.length / 2;
  const r = p1363.slice(0, half);
  const s = p1363.slice(half);

  function integerBytes(val) {
    // Strip leading zeroes but keep at least one byte.
    let start = 0;
    while (start < val.length - 1 && val[start] === 0) {
      start++;
    }
    // If the high bit is set, prepend a 0x00 so it's interpreted as positive.
    const needsPad = val[start] >= 0x80;
    const len = val.length - start + (needsPad ? 1 : 0);
    const out = new Uint8Array(len);
    if (needsPad) {
      out[0] = 0x00;
      out.set(val.slice(start), 1);
    } else {
      out.set(val.slice(start));
    }
    return out;
  }

  const rEnc = integerBytes(r);
  const sEnc = integerBytes(s);

  // Each INTEGER: tag(0x02) + length + value
  const seqLen = 2 + rEnc.length + 2 + sEnc.length;

  // Build DER: SEQUENCE tag(0x30) + length + INTEGER(r) + INTEGER(s)
  // (single-byte length is sufficient: P-256 yields at most 70 content bytes)
  const der = new Uint8Array(2 + seqLen);
  let offset = 0;
  der[offset++] = 0x30; // SEQUENCE tag
  der[offset++] = seqLen;
  der[offset++] = 0x02; // INTEGER tag
  der[offset++] = rEnc.length;
  der.set(rEnc, offset);
  offset += rEnc.length;
  der[offset++] = 0x02; // INTEGER tag
  der[offset++] = sEnc.length;
  der.set(sEnc, offset);

  return der;
}

// base64Encode encodes a string with standard base64.
function base64Encode(str) {
  return encoding.b64encode(str, 'std');
}

// stringToBytes converts a string to bytes via charCodeAt.
// NOTE(review): byte-accurate only for single-byte (ASCII/Latin-1) content.
function stringToBytes(str) {
  const bytes = new Uint8Array(str.length);
  for (let i = 0; i < str.length; i++) {
    bytes[i] = str.charCodeAt(i);
  }
  return bytes;
}

// ---------------------------------------------------------------------------
// setup: generate an ECDSA P-256 key pair
// ---------------------------------------------------------------------------
export async function setup() {
  const keyPair = await crypto.subtle.generateKey(
    { name: 'ECDSA', namedCurve: 'P-256' },
    true, // extractable
    ['sign', 'verify'],
  );

  // Export private key as PKCS8 DER → serialize as byte array for JSON transport
  const privDer = await crypto.subtle.exportKey('pkcs8', keyPair.privateKey);
  const privBytes = Array.from(new Uint8Array(privDer));

  // Export public key as SPKI DER → PEM for the API
  const pubDer = await crypto.subtle.exportKey('spki', keyPair.publicKey);
  const pubPem = derToPem(pubDer, 'PUBLIC KEY');

  return { privBytes, pubPem };
}

// ---------------------------------------------------------------------------
// Writer scenario
//
--------------------------------------------------------------------------- + +// Module-level cache: each VU imports the private key once. +let _privKey = null; + +async function getPrivateKey(privBytes) { + if (_privKey) return _privKey; + const buf = new Uint8Array(privBytes).buffer; + _privKey = await crypto.subtle.importKey( + 'pkcs8', + buf, + { name: 'ECDSA', namedCurve: 'P-256' }, + false, + ['sign'], + ); + return _privKey; +} + +export async function writer(data) { + const privKey = await getPrivateKey(data.privBytes); + const pubPem = data.pubPem; + + // 1. Build a unique artifact + const artifact = `rekor-k6-load:vu=${__VU}:iter=${__ITER}:ts=${Date.now()}`; + const artifactBytes = stringToBytes(artifact); + + // 2. SHA-256 hash the artifact → hex + const hashBuf = await crypto.subtle.digest('SHA-256', artifactBytes); + const hashHex = arrayBufToHex(hashBuf); + + // 3. Sign the artifact bytes (ECDSA with SHA-256 — WebCrypto hashes internally) + const sigP1363 = await crypto.subtle.sign( + { name: 'ECDSA', hash: 'SHA-256' }, + privKey, + artifactBytes, + ); + + // 4. Convert P1363 → DER, then base64 + const sigDer = p1363ToDer(sigP1363); + const sigB64 = arrayBufToBase64(sigDer.buffer); + + // 5. Build the hashedrekord request body + const body = JSON.stringify({ + apiVersion: '0.0.1', + kind: 'hashedrekord', + spec: { + signature: { + content: sigB64, + publicKey: { + content: base64Encode(pubPem), + }, + }, + data: { + hash: { + algorithm: 'sha256', + value: hashHex, + }, + }, + }, + }); + + // 6. 
POST the entry + const res = http.post(`${REKOR_URL}/api/v1/log/entries`, body, { + headers: { 'Content-Type': 'application/json' }, + tags: { name: 'AddEntry' }, + }); + + entryAddDuration.add(res.timings.duration); + + check(res, { + 'writer: status is 201': (r) => r.status === 201, + }); +} + +// --------------------------------------------------------------------------- +// Tailer scenario +// --------------------------------------------------------------------------- + +// Per-VU state for tailing +let _lastReadIndex = -1; +let _baselineSet = false; + +export function tailer() { + // 1. GET log info + const logRes = http.get(`${REKOR_URL}/api/v1/log`, { + tags: { name: 'GetLogInfo' }, + }); + + const ok = check(logRes, { + 'tailer: log info 200': (r) => r.status === 200, + }); + + if (!ok) { + sleep(TAILER_POLL_MS / 1000); + return; + } + + const logInfo = logRes.json(); + const treeSize = parseInt(logInfo.treeSize, 10); + + // 2. Set baseline on first call + if (!_baselineSet) { + _lastReadIndex = treeSize - 1; + _baselineSet = true; + sleep(TAILER_POLL_MS / 1000); + return; + } + + // 3. Read each new entry by index + for (let i = _lastReadIndex + 1; i < treeSize; i++) { + const entryRes = http.get( + `${REKOR_URL}/api/v1/log/entries?logIndex=${i}`, + { tags: { name: 'GetEntryByIndex' } }, + ); + + entryReadDuration.add(entryRes.timings.duration); + + check(entryRes, { + 'tailer: entry read 200': (r) => r.status === 200, + }); + + _lastReadIndex = i; + } + + // 4. Brief pause between polls + sleep(TAILER_POLL_MS / 1000); +}