diff --git a/README.md b/README.md index 0b726bb..9f8102c 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,7 @@ proxy: max_stdin_message_size: 1MB replay_cache_ttl: 2m replay_cache_max_entries: 10000 + credential_cache_ttl: 0s max_output_size: 100MB max_connection_lifetime: 30m diff --git a/docs/CONFIG.md b/docs/CONFIG.md index b30d9f2..a0a0a5a 100644 --- a/docs/CONFIG.md +++ b/docs/CONFIG.md @@ -17,6 +17,7 @@ proxy: max_stdin_message_size: 1MB replay_cache_ttl: 2m replay_cache_max_entries: 10000 + credential_cache_ttl: 0s write_timeout: 30s max_output_size: 100MB max_connection_lifetime: 30m @@ -50,6 +51,7 @@ proxy: max_stdin_message_size: 1MB # Max wrapper->daemon NDJSON message replay_cache_ttl: 2m # Replay protection cache TTL replay_cache_max_entries: 10000 # Replay cache size bound + credential_cache_ttl: 0s # Credential cache TTL (0 disables) write_timeout: 30s # Write deadline per response message max_output_size: 100MB # Kill tool if output exceeds this max_connection_lifetime: 30m # Hard cap on connection duration @@ -127,6 +129,7 @@ proxy: max_stdin_message_size: 1MB # Max stdin/control message size replay_cache_ttl: 2m # Replay detection TTL replay_cache_max_entries: 10000 # Replay cache size cap + credential_cache_ttl: 0s # Credential fetch cache TTL (0 disables) write_timeout: 30s # Write deadline per response message max_output_size: 100MB # Kill tool if output exceeds this max_connection_lifetime: 30m # Hard cap on connection duration @@ -169,6 +172,17 @@ Controls replay protection for authenticated requests. Reuse of the same signed Note: the TTL has a floor of 10 seconds — values below 10s are clamped. +### `credential_cache_ttl` + +Optional in-memory TTL cache for credential fetch results. + +- Default: `0` (disabled) +- Format: Go duration (`30s`, `2m`, `1h`) +- Scope: only `op://` (1Password) and `bw:` (Bitwarden) credential sources +- `claw-wrap check` always bypasses this cache and fetches credentials live + +Use this to reduce repeated upstream secret-store latency for frequently-invoked tools. + ### `write_timeout` Deadline for each response message written back to the wrapper. Default: `30s`. Prevents slow/stalled clients from holding connections open indefinitely. diff --git a/docs/INSTALL.md b/docs/INSTALL.md index d7193ed..66b2d43 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -65,6 +65,7 @@ proxy: max_stdin_message_size: 1MB replay_cache_ttl: 2m replay_cache_max_entries: 10000 + credential_cache_ttl: 0s max_output_size: 100MB max_connection_lifetime: 30m @@ -137,6 +138,7 @@ claw-wrap list # Check credentials are accessible # Run from host/admin context (outside sandbox). # In strict firejail, this may fail by design. +# This check bypasses credential_cache_ttl and always fetches live. claw-wrap check # Test gh through claw-wrap diff --git a/internal/config/config.go b/internal/config/config.go index 80dee17..76fe62a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -46,7 +46,7 @@ type ProxyConfig struct { HMACSecretFile string `yaml:"hmac_secret_file"` // e.g., "/run/openclaw/auth" PassBinary string `yaml:"pass_binary"` // e.g., "/usr/bin/pass" OPBinary string `yaml:"op_binary"` // e.g., "/usr/local/bin/op" - OPTokenFile string `yaml:"op_token_file"` // e.g., "/etc/openclaw/1password.token" + OPTokenFile string `yaml:"op_token_file"` // e.g., "/etc/openclaw/1password.token" BWBinary string `yaml:"bw_binary"` // e.g., "/usr/local/bin/bw" AgeIdentityFile string `yaml:"age_identity_file"` // e.g., "/etc/openclaw/age-identity" MaxConnections int `yaml:"max_connections"` // e.g., 64 @@ -58,6 +58,7 @@ type ProxyConfig struct { MaxConnectionLifetime string `yaml:"max_connection_lifetime"` // e.g., "10m" (0 = unlimited) ReplayCacheTTL string `yaml:"replay_cache_ttl"` // e.g., "2m" ReplayCacheMax int `yaml:"replay_cache_max_entries"` + CredentialCacheTTL string `yaml:"credential_cache_ttl"` // e.g., "30s" (0/empty disables) } // SecurityConfig holds security policy flags. @@ -512,6 +513,19 @@ func (c *Config) GetReplayCacheMaxEntries() int { return c.Proxy.ReplayCacheMax } +// GetCredentialCacheTTL returns credential cache TTL. +// Returns 0 (disabled) when unset, invalid, or non-positive. +func (c *Config) GetCredentialCacheTTL() time.Duration { + if c.Proxy == nil || c.Proxy.CredentialCacheTTL == "" { + return 0 + } + d, err := ParseDuration(c.Proxy.CredentialCacheTTL) + if err != nil || d <= 0 { + return 0 + } + return d +} + // GetWriteTimeout returns the write deadline for daemon→client responses. func (c *Config) GetWriteTimeout() time.Duration { if c.Proxy == nil || c.Proxy.WriteTimeout == "" { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 9c95932..52b42b5 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -631,6 +631,9 @@ func TestGetProxySecurityDefaults(t *testing.T) { if got := cfg.GetReplayCacheMaxEntries(); got != DefaultReplayCacheMaxEntries { t.Errorf("GetReplayCacheMaxEntries() = %d, want %d", got, DefaultReplayCacheMaxEntries) } + if got := cfg.GetCredentialCacheTTL(); got != 0 { + t.Errorf("GetCredentialCacheTTL() = %v, want 0", got) + } } func TestGetReplayCacheTTL_Floor(t *testing.T) { @@ -659,6 +662,28 @@ func TestGetReplayCacheTTL_Floor(t *testing.T) { } } +func TestGetCredentialCacheTTL(t *testing.T) { + tests := []struct { + name string + cfg *Config + want time.Duration + }{ + {"nil proxy", &Config{}, 0}, + {"empty value", &Config{Proxy: &ProxyConfig{}}, 0}, + {"valid", &Config{Proxy: &ProxyConfig{CredentialCacheTTL: "30s"}}, 30 * time.Second}, + {"invalid", &Config{Proxy: &ProxyConfig{CredentialCacheTTL: "bad"}}, 0}, + {"zero", &Config{Proxy: &ProxyConfig{CredentialCacheTTL: "0s"}}, 0}, + {"negative", &Config{Proxy: &ProxyConfig{CredentialCacheTTL: "-5s"}}, 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.GetCredentialCacheTTL(); got != tt.want { + t.Errorf("GetCredentialCacheTTL() = %v, want %v", got, tt.want) + } + }) + } +} + func TestGetMaxOutputSize(t *testing.T) { tests := []struct { name string diff --git a/internal/credentials/cache.go b/internal/credentials/cache.go new file mode 100644 index 0000000..007d708 --- /dev/null +++ b/internal/credentials/cache.go @@ -0,0 +1,199 @@ +package credentials + +import ( + "sync" + "time" +) + +const ( + minCredentialCacheSweepInterval = 5 * time.Second + maxCredentialCacheSweepInterval = time.Minute +) + +type credentialCacheEntry struct { + value string + expiresAt time.Time +} + +type credentialCache struct { + mu sync.RWMutex + ttl time.Duration + entries map[string]credentialCacheEntry + sweeperStop chan struct{} + sweeperDone chan struct{} + sweeperInterval time.Duration +} + +var ( + credentialResultCache = newCredentialCache() + credentialCacheNow = time.Now + credentialTickerFactory = func(interval time.Duration) (<-chan time.Time, func()) { + ticker := time.NewTicker(interval) + return ticker.C, ticker.Stop + } +) + +func newCredentialCache() *credentialCache { + return &credentialCache{ + entries: make(map[string]credentialCacheEntry), + } +} + +// SetCredentialCacheTTL configures the in-memory credential cache TTL. +// Non-positive values disable the cache. +func SetCredentialCacheTTL(ttl time.Duration) { + credentialResultCache.SetTTL(ttl) +} + +func (c *credentialCache) SetTTL(ttl time.Duration) { + var stopCh chan struct{} + var doneCh chan struct{} + startSweeper := false + interval := time.Duration(0) + + c.mu.Lock() + switch { + case ttl <= 0: + c.ttl = 0 + c.entries = make(map[string]credentialCacheEntry) + stopCh, doneCh = c.detachSweeperLocked() + case ttl != c.ttl: + c.ttl = ttl + c.entries = make(map[string]credentialCacheEntry) + stopCh, doneCh = c.detachSweeperLocked() + startSweeper = true + interval = sweepInterval(ttl) + default: + // Same positive TTL should keep current sweeper unchanged. + if c.sweeperStop == nil { + startSweeper = true + interval = sweepInterval(ttl) + } + } + c.mu.Unlock() + + stopCredentialCacheSweeper(stopCh, doneCh) + if startSweeper { + c.startSweeper(interval) + } +} + +func (c *credentialCache) Get(key string, now time.Time) (string, bool) { + c.mu.RLock() + ttl := c.ttl + entry, ok := c.entries[key] + c.mu.RUnlock() + + if ttl <= 0 || !ok { + return "", false + } + if !now.Before(entry.expiresAt) { + c.mu.Lock() + if current, exists := c.entries[key]; exists && !now.Before(current.expiresAt) { + delete(c.entries, key) + } + c.mu.Unlock() + return "", false + } + + return entry.value, true +} + +func (c *credentialCache) Set(key, value string, now time.Time) { + if key == "" || value == "" { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + if c.ttl <= 0 { + return + } + c.sweepExpiredLocked(now) + c.entries[key] = credentialCacheEntry{ + value: value, + expiresAt: now.Add(c.ttl), + } +} + +func (c *credentialCache) sweepExpiredLocked(now time.Time) { + for key, entry := range c.entries { + if !now.Before(entry.expiresAt) { + delete(c.entries, key) + } + } +} + +func (c *credentialCache) detachSweeperLocked() (chan struct{}, chan struct{}) { + stopCh := c.sweeperStop + doneCh := c.sweeperDone + c.sweeperStop = nil + c.sweeperDone = nil + c.sweeperInterval = 0 + return stopCh, doneCh +} + +func stopCredentialCacheSweeper(stopCh, doneCh chan struct{}) { + if stopCh == nil || doneCh == nil { + return + } + close(stopCh) + <-doneCh +} + +func (c *credentialCache) startSweeper(interval time.Duration) { + if interval <= 0 { + return + } + + c.mu.Lock() + if c.ttl <= 0 || c.sweeperStop != nil { + c.mu.Unlock() + return + } + stopCh := make(chan struct{}) + doneCh := make(chan struct{}) + c.sweeperStop = stopCh + c.sweeperDone = doneCh + c.sweeperInterval = interval + c.mu.Unlock() + + tickCh, stopTicker := credentialTickerFactory(interval) + go func() { + defer close(doneCh) + defer stopTicker() + for { + select { + case <-stopCh: + return + case <-tickCh: + now := credentialCacheNow() + c.mu.Lock() + c.sweepExpiredLocked(now) + c.mu.Unlock() + } + } + }() +} + +func sweepInterval(ttl time.Duration) time.Duration { + if ttl <= 0 { + return 0 + } + interval := ttl / 2 + if interval < minCredentialCacheSweepInterval { + interval = minCredentialCacheSweepInterval + } + if interval > maxCredentialCacheSweepInterval { + interval = maxCredentialCacheSweepInterval + } + return interval +} + +func isCredentialCacheableBackend(backend Backend) bool { + return backend == Backend1Password || backend == BackendBitwarden +} + +func credentialCacheKey(parsed *ParsedSource) string { + return string(parsed.Backend) + "\x00" + parsed.Path + "\x00" + parsed.JQExpr +} diff --git a/internal/credentials/cache_test.go b/internal/credentials/cache_test.go new file mode 100644 index 0000000..16318de --- /dev/null +++ b/internal/credentials/cache_test.go @@ -0,0 +1,501 @@ +package credentials + +import ( + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +func writeCountingOPScript(t *testing.T, dir string) string { + t.Helper() + scriptPath := filepath.Join(dir, "op") + script := `#!/bin/sh +counter="$CW_COUNTER_FILE" +count=0 +if [ -f "$counter" ]; then + count=$(cat "$counter") +fi +count=$((count + 1)) +echo "$count" > "$counter" +echo "cached-secret" +` + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("write mock op script: %v", err) + } + return scriptPath +} + +func writeCountingPassScript(t *testing.T, dir string) string { + t.Helper() + scriptPath := filepath.Join(dir, "pass") + script := `#!/bin/sh +counter="$CW_COUNTER_FILE" +count=0 +if [ -f "$counter" ]; then + count=$(cat "$counter") +fi +count=$((count + 1)) +echo "$count" > "$counter" +shift +echo "secret-for-$1" +` + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("write mock pass script: %v", err) + } + return scriptPath +} + +func readCounterValue(t *testing.T, path string) int { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read counter file: %v", err) + } + n, err := strconv.Atoi(strings.TrimSpace(string(data))) + if err != nil { + t.Fatalf("parse counter value: %v", err) + } + return n +} + +func setupCredentialCacheTest(t *testing.T) func() { + t.Helper() + origNow := credentialCacheNow + origTickerFactory := credentialTickerFactory + origTTL := currentCredentialCacheTTL() + credentialCacheNow = time.Now + SetCredentialCacheTTL(0) + return func() { + SetCredentialCacheTTL(0) + credentialCacheNow = origNow + credentialTickerFactory = origTickerFactory + SetCredentialCacheTTL(origTTL) + } +} + +func currentCredentialCacheTTL() time.Duration { + credentialResultCache.mu.RLock() + defer credentialResultCache.mu.RUnlock() + return credentialResultCache.ttl +} + +func cacheEntryExists(key string) bool { + credentialResultCache.mu.RLock() + defer credentialResultCache.mu.RUnlock() + _, ok := credentialResultCache.entries[key] + return ok +} + +func cacheEntryCount() int { + credentialResultCache.mu.RLock() + defer credentialResultCache.mu.RUnlock() + return len(credentialResultCache.entries) +} + +func cacheSweeperState() (interval time.Duration, running bool) { + credentialResultCache.mu.RLock() + defer credentialResultCache.mu.RUnlock() + return credentialResultCache.sweeperInterval, credentialResultCache.sweeperStop != nil +} + +func waitUntil(t *testing.T, cond func() bool, msg string) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal(msg) +} + +type fakeTicker struct { + ch chan time.Time + mu sync.Mutex + stopped bool +} + +func newFakeTicker() *fakeTicker { + return &fakeTicker{ch: make(chan time.Time, 4)} +} + +func (t *fakeTicker) stop() { + t.mu.Lock() + t.stopped = true + t.mu.Unlock() +} + +func (t *fakeTicker) isStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.stopped +} + +func (t *fakeTicker) tick(now time.Time) { + select { + case t.ch <- now: + default: + } +} + +type fakeTickerFactory struct { + mu sync.Mutex + tickers []*fakeTicker + intervals []time.Duration +} + +func (f *fakeTickerFactory) newTicker(interval time.Duration) (<-chan time.Time, func()) { + f.mu.Lock() + defer f.mu.Unlock() + ticker := newFakeTicker() + f.tickers = append(f.tickers, ticker) + f.intervals = append(f.intervals, interval) + return ticker.ch, ticker.stop +} + +func (f *fakeTickerFactory) count() int { + f.mu.Lock() + defer f.mu.Unlock() + return len(f.tickers) +} + +func (f *fakeTickerFactory) ticker(i int) *fakeTicker { + f.mu.Lock() + defer f.mu.Unlock() + return f.tickers[i] +} + +func (f *fakeTickerFactory) interval(i int) time.Duration { + f.mu.Lock() + defer f.mu.Unlock() + return f.intervals[i] +} + +func TestFetch_Caches1PasswordWithinTTL(t *testing.T) { + restore := setupCredentialCacheTest(t) + defer restore() + + tmpDir := t.TempDir() + counterPath := filepath.Join(tmpDir, "counter") + opPath := writeCountingOPScript(t, tmpDir) + t.Setenv("CW_COUNTER_FILE", counterPath) + t.Setenv("CREDENTIALS_DIRECTORY", "") + t.Setenv(opTokenEnvVar, "test-token") + + now := time.Unix(1_700_000_000, 0) + credentialCacheNow = func() time.Time { return now } + SetCredentialCacheTTL(5 * time.Minute) + + for i := 0; i < 2; i++ { + value, err := Fetch("op://Private/GitHub/token", WithOPBinary(opPath)) + if err != nil { + t.Fatalf("Fetch() error = %v", err) + } + if value != "cached-secret" { + t.Fatalf("Fetch() = %q, want %q", value, "cached-secret") + } + } + + if got := readCounterValue(t, counterPath); got != 1 { + t.Fatalf("op invocation count = %d, want 1", got) + } +} + +func TestFetch_CacheTTLExpiryRefetches1Password(t *testing.T) { + restore := setupCredentialCacheTest(t) + defer restore() + + tmpDir := t.TempDir() + counterPath := filepath.Join(tmpDir, "counter") + opPath := writeCountingOPScript(t, tmpDir) + t.Setenv("CW_COUNTER_FILE", counterPath) + t.Setenv("CREDENTIALS_DIRECTORY", "") + t.Setenv(opTokenEnvVar, "test-token") + + now := time.Unix(1_700_000_000, 0) + credentialCacheNow = func() time.Time { return now } + SetCredentialCacheTTL(30 * time.Second) + + if _, err := Fetch("op://Private/GitHub/token", WithOPBinary(opPath)); err != nil { + t.Fatalf("first Fetch() error = %v", err) + } + now = now.Add(31 * time.Second) + if _, err := Fetch("op://Private/GitHub/token", WithOPBinary(opPath)); err != nil { + t.Fatalf("second Fetch() error = %v", err) + } + + if got := readCounterValue(t, counterPath); got != 2 { + t.Fatalf("op invocation count = %d, want 2", got) + } +} + +func TestFetch_WithBypassCacheForcesLiveFetch(t *testing.T) { + restore := setupCredentialCacheTest(t) + defer restore() + + tmpDir := t.TempDir() + counterPath := filepath.Join(tmpDir, "counter") + opPath := writeCountingOPScript(t, tmpDir) + t.Setenv("CW_COUNTER_FILE", counterPath) + t.Setenv("CREDENTIALS_DIRECTORY", "") + t.Setenv(opTokenEnvVar, "test-token") + + now := time.Unix(1_700_000_000, 0) + credentialCacheNow = func() time.Time { return now } + SetCredentialCacheTTL(5 * time.Minute) + + if _, err := Fetch("op://Private/GitHub/token", WithOPBinary(opPath)); err != nil { + t.Fatalf("first Fetch() error = %v", err) + } + if _, err := Fetch("op://Private/GitHub/token", WithOPBinary(opPath), WithBypassCache()); err != nil { + t.Fatalf("second Fetch() error = %v", err) + } + + if got := readCounterValue(t, counterPath); got != 2 { + t.Fatalf("op invocation count = %d, want 2", got) + } +} + +func TestFetch_DoesNotCachePassBackend(t *testing.T) { + restore := setupCredentialCacheTest(t) + defer restore() + + tmpDir := t.TempDir() + counterPath := filepath.Join(tmpDir, "counter") + passPath := writeCountingPassScript(t, tmpDir) + t.Setenv("CW_COUNTER_FILE", counterPath) + + now := time.Unix(1_700_000_000, 0) + credentialCacheNow = func() time.Time { return now } + SetCredentialCacheTTL(5 * time.Minute) + + for i := 0; i < 2; i++ { + got, err := Fetch("pass:test/path", WithPassBinary(passPath)) + if err != nil { + t.Fatalf("Fetch() error = %v", err) + } + want := "secret-for-test/path" + if got != want { + t.Fatalf("Fetch() = %q, want %q", got, want) + } + } + + if got := readCounterValue(t, counterPath); got != 2 { + t.Fatalf("pass invocation count = %d, want 2", got) + } +} + +func TestSetCredentialCacheTTL_DisableClearsEntries(t *testing.T) { + restore := setupCredentialCacheTest(t) + defer restore() + + tmpDir := t.TempDir() + counterPath := filepath.Join(tmpDir, "counter") + opPath := writeCountingOPScript(t, tmpDir) + t.Setenv("CW_COUNTER_FILE", counterPath) + t.Setenv("CREDENTIALS_DIRECTORY", "") + t.Setenv(opTokenEnvVar, "test-token") + + now := time.Unix(1_700_000_000, 0) + credentialCacheNow = func() time.Time { return now } + + SetCredentialCacheTTL(5 * time.Minute) + if _, err := Fetch("op://Private/GitHub/token", WithOPBinary(opPath)); err != nil { + t.Fatalf("first Fetch() error = %v", err) + } + + SetCredentialCacheTTL(0) + SetCredentialCacheTTL(5 * time.Minute) + + if _, err := Fetch("op://Private/GitHub/token", WithOPBinary(opPath)); err != nil { + t.Fatalf("second Fetch() error = %v", err) + } + + if got := readCounterValue(t, counterPath); got != 2 { + t.Fatalf("op invocation count = %d, want 2", got) + } +} + +func TestSetCredentialCacheTTL_ChangingTTLFlushesEntries(t *testing.T) { + restore := setupCredentialCacheTest(t) + defer restore() + + tmpDir := t.TempDir() + counterPath := filepath.Join(tmpDir, "counter") + opPath := writeCountingOPScript(t, tmpDir) + t.Setenv("CW_COUNTER_FILE", counterPath) + t.Setenv("CREDENTIALS_DIRECTORY", "") + t.Setenv(opTokenEnvVar, "test-token") + + now := time.Unix(1_700_000_000, 0) + credentialCacheNow = func() time.Time { return now } + + SetCredentialCacheTTL(5 * time.Minute) + if _, err := Fetch("op://Private/GitHub/token", WithOPBinary(opPath)); err != nil { + t.Fatalf("first Fetch() error = %v", err) + } + + SetCredentialCacheTTL(1 * time.Minute) + if _, err := Fetch("op://Private/GitHub/token", WithOPBinary(opPath)); err != nil { + t.Fatalf("second Fetch() error = %v", err) + } + + if got := readCounterValue(t, counterPath); got != 2 { + t.Fatalf("op invocation count = %d, want 2", got) + } +} + +func TestCredentialCache_ActiveSweepEvictsExpiredWithoutRead(t *testing.T) { + restore := setupCredentialCacheTest(t) + defer restore() + + factory := &fakeTickerFactory{} + credentialTickerFactory = factory.newTicker + + now := time.Unix(1_700_000_000, 0) + credentialCacheNow = func() time.Time { return now } + + SetCredentialCacheTTL(10 * time.Second) + if got := factory.count(); got != 1 { + t.Fatalf("ticker count = %d, want 1", got) + } + if got := factory.interval(0); got != 5*time.Second { + t.Fatalf("sweeper interval = %v, want 5s", got) + } + + key := "op\x00op://Private/Item/field\x00" + credentialResultCache.Set(key, "secret", now) + if !cacheEntryExists(key) { + t.Fatal("expected cache entry to exist before sweep") + } + + now = now.Add(11 * time.Second) + factory.ticker(0).tick(now) + + waitUntil(t, func() bool { return !cacheEntryExists(key) }, "expected sweeper to evict expired entry") +} + +func TestCredentialCache_DisableClearsEntriesAndStopsSweeper(t *testing.T) { + restore := setupCredentialCacheTest(t) + defer restore() + + factory := &fakeTickerFactory{} + credentialTickerFactory = factory.newTicker + + now := time.Unix(1_700_000_000, 0) + credentialCacheNow = func() time.Time { return now } + + SetCredentialCacheTTL(30 * time.Second) + credentialResultCache.Set("k1", "v1", now) + if got := cacheEntryCount(); got != 1 { + t.Fatalf("cache entry count = %d, want 1", got) + } + + SetCredentialCacheTTL(0) + + if got := cacheEntryCount(); got != 0 { + t.Fatalf("cache entry count after disable = %d, want 0", got) + } + interval, running := cacheSweeperState() + if running { + t.Fatal("expected sweeper to be stopped") + } + if interval != 0 { + t.Fatalf("sweeper interval = %v, want 0", interval) + } + if !factory.ticker(0).isStopped() { + t.Fatal("expected ticker stop func to be called") + } +} + +func TestCredentialCache_ChangingTTLRestartsSweeper(t *testing.T) { + restore := setupCredentialCacheTest(t) + defer restore() + + factory := &fakeTickerFactory{} + credentialTickerFactory = factory.newTicker + + SetCredentialCacheTTL(20 * time.Second) + if got := factory.count(); got != 1 { + t.Fatalf("ticker count after first set = %d, want 1", got) + } + if got := factory.interval(0); got != 10*time.Second { + t.Fatalf("first interval = %v, want 10s", got) + } + + SetCredentialCacheTTL(40 * time.Second) + if got := factory.count(); got != 2 { + t.Fatalf("ticker count after ttl change = %d, want 2", got) + } + if !factory.ticker(0).isStopped() { + t.Fatal("expected first ticker to be stopped after ttl change") + } + if got := factory.interval(1); got != 20*time.Second { + t.Fatalf("second interval = %v, want 20s", got) + } + interval, running := cacheSweeperState() + if !running { + t.Fatal("expected sweeper to be running after ttl change") + } + if interval != 20*time.Second { + t.Fatalf("stored sweeper interval = %v, want 20s", interval) + } +} + +func TestCredentialCache_SameTTLDoesNotSpawnDuplicateSweeper(t *testing.T) { + restore := setupCredentialCacheTest(t) + defer restore() + + factory := &fakeTickerFactory{} + credentialTickerFactory = factory.newTicker + + SetCredentialCacheTTL(30 * time.Second) + SetCredentialCacheTTL(30 * time.Second) + + if got := factory.count(); got != 1 { + t.Fatalf("ticker count = %d, want 1", got) + } + if factory.ticker(0).isStopped() { + t.Fatal("did not expect ticker to stop when ttl is unchanged") + } +} + +func TestSweepInterval(t *testing.T) { + tests := []struct { + name string + ttl time.Duration + want time.Duration + }{ + {"disabled", 0, 0}, + {"short ttl clamps min", 6 * time.Second, 5 * time.Second}, + {"normal ttl half", 40 * time.Second, 20 * time.Second}, + {"long ttl clamps max", 10 * time.Minute, time.Minute}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := sweepInterval(tt.ttl); got != tt.want { + t.Fatalf("sweepInterval(%v) = %v, want %v", tt.ttl, got, tt.want) + } + }) + } +} + +func TestCredentialCacheKey_NormalizesBackendPathAndJQ(t *testing.T) { + parsed := &ParsedSource{ + Backend: Backend1Password, + Path: "op://Vault/Item/field", + JQExpr: ".foo", + } + key := credentialCacheKey(parsed) + want := fmt.Sprintf("%s\x00%s\x00%s", Backend1Password, "op://Vault/Item/field", ".foo") + if key != want { + t.Fatalf("credentialCacheKey() = %q, want %q", key, want) + } +} diff --git a/internal/credentials/credentials.go b/internal/credentials/credentials.go index 9092139..1b50982 100644 --- a/internal/credentials/credentials.go +++ b/internal/credentials/credentials.go @@ -22,9 +22,10 @@ var findTrustedBinaryFunc = paths.FindTrustedBinary // FetchOptions holds configuration for credential fetching. type FetchOptions struct { - PassBinary string - OPBinary string - BWBinary string + PassBinary string + OPBinary string + BWBinary string + BypassCache bool } // FetchOption configures credential fetching. @@ -51,6 +52,13 @@ func WithBWBinary(path string) FetchOption { } } +// WithBypassCache forces live credential fetches and bypasses result caching. +func WithBypassCache() FetchOption { + return func(o *FetchOptions) { + o.BypassCache = true + } +} + // Fetch retrieves a credential from the specified source. // Source formats: // - pass:path/in/store - fetch from password store @@ -77,43 +85,75 @@ func Fetch(source string, opts ...FetchOption) (string, error) { } ctx := context.Background() + cacheEligible := isCredentialCacheableBackend(parsed.Backend) && !options.BypassCache + cacheKey := "" + now := credentialCacheNow() + if cacheEligible { + cacheKey = credentialCacheKey(parsed) + if cached, ok := credentialResultCache.Get(cacheKey, now); ok { + return cached, nil + } + } + var result string switch parsed.Backend { case BackendEnv: - result, err := fetchFromEnvFile(parsed.Path) + result, err = fetchFromEnvFile(parsed.Path) if err != nil { return "", err } if parsed.HasJQ() { - return ApplyJQ(ctx, []byte(result), parsed.JQExpr) + result, err = ApplyJQ(ctx, []byte(result), parsed.JQExpr) + if err != nil { + return "", err + } } - return result, nil case BackendPass: - result, err := fetchFromPass(options.PassBinary, parsed.Path) + result, err = fetchFromPass(options.PassBinary, parsed.Path) if err != nil { return "", err } if parsed.HasJQ() { - return ApplyJQ(ctx, []byte(result), parsed.JQExpr) + result, err = ApplyJQ(ctx, []byte(result), parsed.JQExpr) + if err != nil { + return "", err + } } - return result, nil case Backend1Password: - return fetchFrom1Password(ctx, parsed, options.OPBinary) + result, err = fetchFrom1Password(ctx, parsed, options.OPBinary) + if err != nil { + return "", err + } case BackendAge: - return fetchFromAge(ctx, parsed) + result, err = fetchFromAge(ctx, parsed) + if err != nil { + return "", err + } case BackendKeychain: - return fetchFromKeychain(ctx, parsed) + result, err = fetchFromKeychain(ctx, parsed) + if err != nil { + return "", err + } case BackendBitwarden: - return fetchFromBitwarden(ctx, parsed, options.BWBinary) + result, err = fetchFromBitwarden(ctx, parsed, options.BWBinary) + if err != nil { + return "", err + } default: return "", fmt.Errorf("unknown credential backend") } + + if cacheEligible && result != "" { + insertNow := credentialCacheNow() + credentialResultCache.Set(cacheKey, result, insertNow) + } + return result, nil } // fetchFromEnvFile reads a credential from the env file. diff --git a/internal/credentials/credentials_test.go b/internal/credentials/credentials_test.go index 56eca84..0e63950 100644 --- a/internal/credentials/credentials_test.go +++ b/internal/credentials/credentials_test.go @@ -59,6 +59,14 @@ func TestWithBWBinary(t *testing.T) { } } +func TestWithBypassCache(t *testing.T) { + opts := &FetchOptions{} + WithBypassCache()(opts) + if !opts.BypassCache { + t.Error("BypassCache = false, want true") + } +} + func TestFetch_DefaultPassBinary(t *testing.T) { // Verify the default is /usr/bin/pass when no option is provided. options := &FetchOptions{PassBinary: "/usr/bin/pass"} diff --git a/internal/daemon/admin_check_cache_test.go b/internal/daemon/admin_check_cache_test.go new file mode 100644 index 0000000..898bbac --- /dev/null +++ b/internal/daemon/admin_check_cache_test.go @@ -0,0 +1,125 @@ +package daemon + +import ( + "encoding/json" + "net" + "os" + "strconv" + "testing" + "time" + + "claw-wrap/internal/auth" + "claw-wrap/internal/config" + "claw-wrap/internal/credentials" + "claw-wrap/internal/protocol" +) + +func TestHandleAdminRequest_CheckBypassesCredentialCache(t *testing.T) { + origFetch := fetchCredentialFunc + origExe := resolvePeerExecutableFunc + origArgv0 := resolvePeerArgv0Func + defer func() { + fetchCredentialFunc = origFetch + resolvePeerExecutableFunc = origExe + resolvePeerArgv0Func = origArgv0 + }() + + resolvePeerExecutableFunc = func(pid int32) (string, error) { + return "/usr/local/bin/claw-wrap", nil + } + resolvePeerArgv0Func = func(pid int32) (string, error) { + return "/usr/local/bin/claw-wrap", nil + } + + fetchCalled := false + bypassCache := false + fetchCredentialFunc = func(source string, opts ...credentials.FetchOption) (string, error) { + fetchCalled = true + if source != "op://Private/GitHub/token" { + t.Fatalf("source = %q, want %q", source, "op://Private/GitHub/token") + } + fetchOpts := &credentials.FetchOptions{} + for _, opt := range opts { + opt(fetchOpts) + } + bypassCache = fetchOpts.BypassCache + return "secret-value", nil + } + + d := New(WithAllowedBinaries([]string{"/usr/local/bin/claw-wrap"})) + d.secret = []byte("0123456789abcdef0123456789abcdef") + d.replayCache = auth.NewReplayCache(2*time.Minute, 1000) + + cfg := &config.Config{ + Credentials: map[string]config.CredentialDef{ + "github-token": {Source: "op://Private/GitHub/token"}, + }, + Tools: map[string]config.ToolDef{}, + } + d.cfg = cfg + + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + nonce, err := auth.GenerateNonce() + if err != nil { + t.Fatalf("GenerateNonce() error = %v", err) + } + hmac, err := auth.ComputeHMAC(d.secret, timestamp, "admin:check", "", nil, nonce) + if err != nil { + t.Fatalf("ComputeHMAC() error = %v", err) + } + + req := protocol.AdminRequest{ + Version: protocol.ProtocolVersion, + Admin: "check", + Timestamp: timestamp, + Nonce: nonce, + HMAC: hmac, + } + payload, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal request: %v", err) + } + + server, client := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + defer server.Close() + d.handleAdminRequest(server, payload, cfg, uint32(os.Getuid()), 1234) + }() + + buf := make([]byte, 64*1024) + n, err := client.Read(buf) + if err != nil { + t.Fatalf("read response: %v", err) + } + _ = client.Close() + <-done + + var adminErr map[string]string + if err := json.Unmarshal(buf[:n], &adminErr); err == nil { + if msg := adminErr["error"]; msg != "" { + t.Fatalf("admin error response: %s", msg) + } + } + + var resp protocol.AdminCheckResponse + if err := json.Unmarshal(buf[:n], &resp); err != nil { + t.Fatalf("unmarshal admin check response: %v", err) + } + + if !fetchCalled { + t.Fatal("fetchCredentialFunc was not called") + } + if !bypassCache { + t.Fatal("expected admin check to call Fetch with WithBypassCache()") + } + + info, ok := resp.Credentials["github-token"] + if !ok { + t.Fatal("response missing github-token credential status") + } + if info.Status != "ok" { + t.Fatalf("credential status = %q, want %q", info.Status, "ok") + } +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index dc3c158..e20ddcb 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -55,6 +55,8 @@ var ( resolvePeerArgv0Func = resolvePeerArgv0 setAgeIdentityFileFunc = credentials.SetAgeIdentityFile setOPTokenFileFunc = credentials.SetOPTokenFile + setCredentialCacheTTLFunc = credentials.SetCredentialCacheTTL + fetchCredentialFunc = credentials.Fetch cleanupBWSessionFunc = credentials.CleanupBWSession ) @@ -118,6 +120,7 @@ func (d *Daemon) Run() error { log.Printf("[INFO] Loaded %d credentials from config", len(cfg.Credentials)) setAgeIdentityFileFunc(cfg.GetAgeIdentityFile()) setOPTokenFileFunc(cfg.GetOPTokenFile()) + setCredentialCacheTTLFunc(cfg.GetCredentialCacheTTL()) defer cleanupBWSessionFunc() secret, err := auth.GenerateSecret() @@ -237,6 +240,7 @@ func (d *Daemon) reloadConfig() error { // Configure credential backends setAgeIdentityFileFunc(newCfg.GetAgeIdentityFile()) setOPTokenFileFunc(newCfg.GetOPTokenFile()) + setCredentialCacheTTLFunc(newCfg.GetCredentialCacheTTL()) d.cfgMu.Lock() d.cfg = newCfg @@ -414,11 +418,12 @@ func (d *Daemon) handleAdminRequest(conn net.Conn, data []byte, cfg *config.Conf case "check": resp := protocol.AdminCheckResponse{Credentials: make(map[string]protocol.CredentialInfo), Version: d.version} for name, credDef := range cfg.Credentials { - value, err := credentials.Fetch( + value, err := fetchCredentialFunc( credDef.Source, credentials.WithPassBinary(cfg.GetPassBinary()), credentials.WithOPBinary(cfg.GetOPBinary()), credentials.WithBWBinary(cfg.GetBWBinary()), + credentials.WithBypassCache(), ) if err != nil || value == "" { if err != nil { diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index ee0b40a..e94cb7f 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" "testing" + "time" ) func TestDaemon_ReloadConfig(t *testing.T) { @@ -167,6 +168,7 @@ func TestDaemon_ReloadConfig_SetsBackendTokenPaths(t *testing.T) { proxy: age_identity_file: /tmp/age-a op_token_file: /tmp/op-a + credential_cache_ttl: 20s tools: {} ` if err := os.WriteFile(configPath, []byte(initialConfig), 0o644); err != nil { @@ -175,19 +177,25 @@ tools: {} origSetAge := setAgeIdentityFileFunc origSetOP := setOPTokenFileFunc + origSetCacheTTL := setCredentialCacheTTLFunc defer func() { setAgeIdentityFileFunc = origSetAge setOPTokenFileFunc = origSetOP + setCredentialCacheTTLFunc = origSetCacheTTL }() var seenAge []string var seenOP []string + var seenCacheTTL []time.Duration setAgeIdentityFileFunc = func(path string) { seenAge = append(seenAge, path) } setOPTokenFileFunc = func(path string) { seenOP = append(seenOP, path) } + setCredentialCacheTTLFunc = func(ttl time.Duration) { + seenCacheTTL = append(seenCacheTTL, ttl) + } d := New(WithConfigPath(configPath)) if err := d.reloadConfig(); err != nil { @@ -198,6 +206,7 @@ tools: {} proxy: age_identity_file: /tmp/age-b op_token_file: /tmp/op-b + credential_cache_ttl: 45s tools: {} ` if err := os.WriteFile(configPath, []byte(updatedConfig), 0o644); err != nil { @@ -219,6 +228,12 @@ tools: {} if seenOP[0] != "/tmp/op-a" || seenOP[1] != "/tmp/op-b" { t.Fatalf("setOPTokenFileFunc calls = %v, want [/tmp/op-a /tmp/op-b]", seenOP) } + if len(seenCacheTTL) != 2 { + t.Fatalf("setCredentialCacheTTLFunc called %d times, want 2", len(seenCacheTTL)) + } + if seenCacheTTL[0] != 20*time.Second || seenCacheTTL[1] != 45*time.Second { + t.Fatalf("setCredentialCacheTTLFunc calls = %v, want [20s 45s]", seenCacheTTL) + } } func TestDaemon_Run_ConfiguresBackendsAndCleansUpOnFailure(t *testing.T) { @@ -228,6 +243,7 @@ func TestDaemon_Run_ConfiguresBackendsAndCleansUpOnFailure(t *testing.T) { proxy: age_identity_file: /tmp/startup-age op_token_file: /tmp/startup-op + credential_cache_ttl: 30s hmac_secret_file: /this/path/does/not/exist/auth tools: {} ` @@ -237,10 +253,12 @@ tools: {} origSetAge := setAgeIdentityFileFunc origSetOP := setOPTokenFileFunc + origSetCacheTTL := setCredentialCacheTTLFunc origCleanup := cleanupBWSessionFunc defer func() { setAgeIdentityFileFunc = origSetAge setOPTokenFileFunc = origSetOP + setCredentialCacheTTLFunc = origSetCacheTTL cleanupBWSessionFunc = origCleanup }() @@ -252,6 +270,10 @@ tools: {} setOPTokenFileFunc = func(path string) { setOPPath = path } + setCacheTTL := time.Duration(0) + setCredentialCacheTTLFunc = func(ttl time.Duration) { + setCacheTTL = ttl + } cleanupCalls := 0 cleanupBWSessionFunc = func() { @@ -273,6 +295,9 @@ tools: {} if setOPPath != "/tmp/startup-op" { t.Fatalf("setOPTokenFileFunc path = %q, want %q", setOPPath, "/tmp/startup-op") } + if setCacheTTL != 30*time.Second { + t.Fatalf("setCredentialCacheTTLFunc ttl = %v, want 30s", setCacheTTL) + } if cleanupCalls != 1 { t.Fatalf("cleanupBWSessionFunc calls = %d, want 1", cleanupCalls) }